[x86] Restructure the parallel bitmath lowering of popcount into
a separate routine, generalize it to work for all the integer vector
sizes, and do general code cleanups.

This dramatically improves the lowering of byte and short element vector
popcount, but more importantly it will make the introduction of the LUT
approach much cleaner.

The biggest cleanup I've done is to just force the legalizer to do the
bitcasting we need: these lowerings now run iteratively through
legalization, and that makes the code much simpler IMO. Other changes
were minor, mostly naming and splitting things up in a way that makes it
clearer what is going on.

The other significant change is to use a different final horizontal-sum
approach (see the scalar sketch below). It takes the same number of
instructions as the old method, but shifts left instead of right so that
a single final shift right clears everything but the sum. This seems
likely to be better than a mask, which usually has to be loaded from
memory; it is certainly fewer uops. This is also temporary: both this
and the LUT approach need horizontal adds to finish the computation, and
we have more clever approaches than this one that I'll switch over to.

git-svn-id: https://llvm.org/svn/llvm-project/llvm/trunk@238635 91177308-0d34-0410-b5e6-96231b3b80d8
Chandler Carruth 2015-05-30 03:20:55 +00:00
parent 586c0042da
commit 43d1e87d73
3 changed files with 210 additions and 2286 deletions
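
The horizontal-sum change described in the commit message is easiest to see
on a single scalar element. Below is a minimal C++ sketch, not the actual
SelectionDAG lowering: the function names are made up for illustration, and
__builtin_popcount (a GCC/Clang builtin) is used only as a reference check.
The bit manipulation mirrors the bithacks algorithm and the new
shift-left / final shift-right finish, shown here for a 32-bit element.

// Scalar sketch of the vectorized popcount lowering (illustrative only).
#include <cassert>
#include <cstdint>
#include <cstdio>

// Step 1: the parallel bit-counting trick from the Stanford bithacks page;
// it leaves the population count of each byte stored in that byte.
static uint32_t bytewisePopcount(uint32_t V) {
  V = V - ((V >> 1) & 0x55555555u);                 // 2-bit partial sums
  V = (V & 0x33333333u) + ((V >> 2) & 0x33333333u); // 4-bit partial sums
  V = (V + (V >> 4)) & 0x0F0F0F0Fu;                 // per-byte sums
  return V;
}

// Old horizontal sum: shift right, add, and mask off everything but the
// final count. In the vector code the mask is a constant that typically
// has to be loaded from memory.
static uint32_t horizontalSumMask(uint32_t V) {
  V = V + (V >> 8);
  V = V + (V >> 16);
  return V & 0x3Fu; // a 32-bit popcount fits in 6 bits
}

// New horizontal sum: shift *left* and add, so the total accumulates in
// the high byte; one final logical shift right by (Len - 8) both moves it
// down and clears the partial sums -- no mask needed.
static uint32_t horizontalSumShift(uint32_t V) {
  V = V + (V << 16);
  V = V + (V << 8);
  return V >> 24; // Len - 8 == 24 for 32-bit elements
}

int main() {
  const uint32_t Tests[] = {0, 1, 0xFF, 0xDEADBEEF, 0xFFFFFFFF};
  for (uint32_t X : Tests) {
    uint32_t Bytes = bytewisePopcount(X);
    assert(horizontalSumMask(Bytes) == horizontalSumShift(Bytes));
    assert(horizontalSumShift(Bytes) == (uint32_t)__builtin_popcount(X));
    printf("popcount(%#010x) = %u\n", X, horizontalSumShift(Bytes));
  }
}

In the vector lowering the adds are byte-element adds (so partial sums never
carry across byte boundaries), and the single trailing element shift
(vpsrld $24 for i32, vpsrlq $56 for i64 in the tests below) replaces the
broadcast mask load the old sequence needed.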

@@ -846,8 +846,10 @@ X86TargetLowering::X86TargetLowering(const X86TargetMachine &TM,
// know to perform better than using the popcnt instructions on each vector
// element. If popcnt isn't supported, always provide the custom version.
if (!Subtarget->hasPOPCNT()) {
setOperationAction(ISD::CTPOP, MVT::v4i32, Custom);
setOperationAction(ISD::CTPOP, MVT::v2i64, Custom);
setOperationAction(ISD::CTPOP, MVT::v4i32, Custom);
setOperationAction(ISD::CTPOP, MVT::v8i16, Custom);
setOperationAction(ISD::CTPOP, MVT::v16i8, Custom);
}
// Custom lower build_vector, vector_shuffle, and extract_vector_elt.
@@ -17327,141 +17329,131 @@ static SDValue LowerBITCAST(SDValue Op, const X86Subtarget *Subtarget,
return SDValue();
}
static SDValue LowerCTPOP(SDValue Op, const X86Subtarget *Subtarget,
SelectionDAG &DAG) {
SDNode *Node = Op.getNode();
SDLoc dl(Node);
Op = Op.getOperand(0);
EVT VT = Op.getValueType();
static SDValue LowerVectorCTPOPBitmath(SDValue Op, SDLoc DL,
const X86Subtarget *Subtarget,
SelectionDAG &DAG) {
MVT VT = Op.getSimpleValueType();
assert((VT.is128BitVector() || VT.is256BitVector()) &&
"CTPOP lowering only implemented for 128/256-bit wide vector types");
unsigned NumElts = VT.getVectorNumElements();
EVT EltVT = VT.getVectorElementType();
unsigned Len = EltVT.getSizeInBits();
int VecSize = VT.getSizeInBits();
int NumElts = VT.getVectorNumElements();
MVT EltVT = VT.getVectorElementType();
int Len = EltVT.getSizeInBits();
// This is the vectorized version of the "best" algorithm from
// http://graphics.stanford.edu/~seander/bithacks.html#CountBitsSetParallel
// with a minor tweak to use a series of adds + shifts instead of vector
// multiplications. Implemented for the v2i64, v4i64, v4i32, v8i32 types:
// multiplications. Implemented for all integer vector types.
//
// v2i64, v4i64, v4i32 => Only profitable w/ popcnt disabled
// v8i32 => Always profitable
//
// FIXME: There a couple of possible improvements:
//
// 1) Support for i8 and i16 vectors (needs measurements if popcnt enabled).
// 2) Use strategies from http://wm.ite.pl/articles/sse-popcount.html
//
assert(EltVT.isInteger() && (Len == 32 || Len == 64) && Len % 8 == 0 &&
"CTPOP not implemented for this vector element type.");
// FIXME: Use strategies from http://wm.ite.pl/articles/sse-popcount.html
// X86 canonicalize ANDs to vXi64, generate the appropriate bitcasts to avoid
// extra legalization.
bool NeedsBitcast = EltVT == MVT::i32;
MVT BitcastVT = VT.is256BitVector() ? MVT::v4i64 : MVT::v2i64;
SDValue Cst55 = DAG.getConstant(APInt::getSplat(Len, APInt(8, 0x55)), DL,
EltVT);
SDValue Cst33 = DAG.getConstant(APInt::getSplat(Len, APInt(8, 0x33)), DL,
EltVT);
SDValue Cst0F = DAG.getConstant(APInt::getSplat(Len, APInt(8, 0x0F)), DL,
EltVT);
SDValue Cst55 = DAG.getConstant(APInt::getSplat(Len, APInt(8, 0x55)), dl,
EltVT);
SDValue Cst33 = DAG.getConstant(APInt::getSplat(Len, APInt(8, 0x33)), dl,
EltVT);
SDValue Cst0F = DAG.getConstant(APInt::getSplat(Len, APInt(8, 0x0F)), dl,
EltVT);
SDValue V = Op;
// v = v - ((v >> 1) & 0x55555555...)
SmallVector<SDValue, 8> Ones(NumElts, DAG.getConstant(1, dl, EltVT));
SDValue OnesV = DAG.getNode(ISD::BUILD_VECTOR, dl, VT, Ones);
SDValue Srl = DAG.getNode(ISD::SRL, dl, VT, Op, OnesV);
if (NeedsBitcast)
Srl = DAG.getNode(ISD::BITCAST, dl, BitcastVT, Srl);
SmallVector<SDValue, 8> Ones(NumElts, DAG.getConstant(1, DL, EltVT));
SDValue OnesV = DAG.getNode(ISD::BUILD_VECTOR, DL, VT, Ones);
SDValue Srl = DAG.getNode(ISD::SRL, DL, VT, V, OnesV);
SmallVector<SDValue, 8> Mask55(NumElts, Cst55);
SDValue M55 = DAG.getNode(ISD::BUILD_VECTOR, dl, VT, Mask55);
if (NeedsBitcast)
M55 = DAG.getNode(ISD::BITCAST, dl, BitcastVT, M55);
SDValue M55 = DAG.getNode(ISD::BUILD_VECTOR, DL, VT, Mask55);
SDValue And = DAG.getNode(ISD::AND, DL, Srl.getValueType(), Srl, M55);
SDValue And = DAG.getNode(ISD::AND, dl, Srl.getValueType(), Srl, M55);
if (VT != And.getValueType())
And = DAG.getNode(ISD::BITCAST, dl, VT, And);
SDValue Sub = DAG.getNode(ISD::SUB, dl, VT, Op, And);
V = DAG.getNode(ISD::SUB, DL, VT, V, And);
// v = (v & 0x33333333...) + ((v >> 2) & 0x33333333...)
SmallVector<SDValue, 8> Mask33(NumElts, Cst33);
SDValue M33 = DAG.getNode(ISD::BUILD_VECTOR, dl, VT, Mask33);
SmallVector<SDValue, 8> Twos(NumElts, DAG.getConstant(2, dl, EltVT));
SDValue TwosV = DAG.getNode(ISD::BUILD_VECTOR, dl, VT, Twos);
SDValue M33 = DAG.getNode(ISD::BUILD_VECTOR, DL, VT, Mask33);
SDValue AndLHS = DAG.getNode(ISD::AND, DL, M33.getValueType(), V, M33);
Srl = DAG.getNode(ISD::SRL, dl, VT, Sub, TwosV);
if (NeedsBitcast) {
Srl = DAG.getNode(ISD::BITCAST, dl, BitcastVT, Srl);
M33 = DAG.getNode(ISD::BITCAST, dl, BitcastVT, M33);
Sub = DAG.getNode(ISD::BITCAST, dl, BitcastVT, Sub);
}
SmallVector<SDValue, 8> Twos(NumElts, DAG.getConstant(2, DL, EltVT));
SDValue TwosV = DAG.getNode(ISD::BUILD_VECTOR, DL, VT, Twos);
Srl = DAG.getNode(ISD::SRL, DL, VT, V, TwosV);
SDValue AndRHS = DAG.getNode(ISD::AND, DL, M33.getValueType(), Srl, M33);
SDValue AndRHS = DAG.getNode(ISD::AND, dl, M33.getValueType(), Srl, M33);
SDValue AndLHS = DAG.getNode(ISD::AND, dl, M33.getValueType(), Sub, M33);
if (VT != AndRHS.getValueType()) {
AndRHS = DAG.getNode(ISD::BITCAST, dl, VT, AndRHS);
AndLHS = DAG.getNode(ISD::BITCAST, dl, VT, AndLHS);
}
SDValue Add = DAG.getNode(ISD::ADD, dl, VT, AndLHS, AndRHS);
V = DAG.getNode(ISD::ADD, DL, VT, AndLHS, AndRHS);
// v = (v + (v >> 4)) & 0x0F0F0F0F...
SmallVector<SDValue, 8> Fours(NumElts, DAG.getConstant(4, dl, EltVT));
SDValue FoursV = DAG.getNode(ISD::BUILD_VECTOR, dl, VT, Fours);
Srl = DAG.getNode(ISD::SRL, dl, VT, Add, FoursV);
Add = DAG.getNode(ISD::ADD, dl, VT, Add, Srl);
SmallVector<SDValue, 8> Fours(NumElts, DAG.getConstant(4, DL, EltVT));
SDValue FoursV = DAG.getNode(ISD::BUILD_VECTOR, DL, VT, Fours);
Srl = DAG.getNode(ISD::SRL, DL, VT, V, FoursV);
SDValue Add = DAG.getNode(ISD::ADD, DL, VT, V, Srl);
SmallVector<SDValue, 8> Mask0F(NumElts, Cst0F);
SDValue M0F = DAG.getNode(ISD::BUILD_VECTOR, dl, VT, Mask0F);
if (NeedsBitcast) {
Add = DAG.getNode(ISD::BITCAST, dl, BitcastVT, Add);
M0F = DAG.getNode(ISD::BITCAST, dl, BitcastVT, M0F);
}
And = DAG.getNode(ISD::AND, dl, M0F.getValueType(), Add, M0F);
if (VT != And.getValueType())
And = DAG.getNode(ISD::BITCAST, dl, VT, And);
SDValue M0F = DAG.getNode(ISD::BUILD_VECTOR, DL, VT, Mask0F);
// The algorithm mentioned above uses:
// v = (v * 0x01010101...) >> (Len - 8)
V = DAG.getNode(ISD::AND, DL, M0F.getValueType(), Add, M0F);
// At this point, V contains the byte-wise population count, and we are
// merely doing a horizontal sum if necessary to get the wider element
// counts.
//
// Change it to use vector adds + vector shifts which yield faster results on
// Haswell than using vector integer multiplication.
//
// For i32 elements:
// v = v + (v >> 8)
// v = v + (v >> 16)
//
// For i64 elements:
// v = v + (v >> 8)
// v = v + (v >> 16)
// v = v + (v >> 32)
//
Add = And;
// FIXME: There is a different lowering strategy above for the horizontal sum
// of byte-wise population counts. This one and that one should be merged,
// using the fastest of the two for each size.
MVT ByteVT = MVT::getVectorVT(MVT::i8, VecSize / 8);
MVT ShiftVT = MVT::getVectorVT(MVT::i64, VecSize / 64);
V = DAG.getNode(ISD::BITCAST, DL, ByteVT, V);
SmallVector<SDValue, 8> Csts;
for (unsigned i = 8; i <= Len/2; i *= 2) {
Csts.assign(NumElts, DAG.getConstant(i, dl, EltVT));
SDValue CstsV = DAG.getNode(ISD::BUILD_VECTOR, dl, VT, Csts);
Srl = DAG.getNode(ISD::SRL, dl, VT, Add, CstsV);
Add = DAG.getNode(ISD::ADD, dl, VT, Add, Srl);
Csts.clear();
assert(Len <= 64 && "We don't support element sizes of more than 64 bits!");
assert(isPowerOf2_32(Len) && "Only power of two element sizes supported!");
for (int i = Len; i > 8; i /= 2) {
Csts.assign(VecSize / 64, DAG.getConstant(i / 2, DL, MVT::i64));
SDValue Shl = DAG.getNode(
ISD::SHL, DL, ShiftVT, DAG.getNode(ISD::BITCAST, DL, ShiftVT, V),
DAG.getNode(ISD::BUILD_VECTOR, DL, ShiftVT, Csts));
V = DAG.getNode(ISD::ADD, DL, ByteVT, V,
DAG.getNode(ISD::BITCAST, DL, ByteVT, Shl));
}
// The result is on the least significant 6-bits on i32 and 7-bits on i64.
SDValue Cst3F = DAG.getConstant(APInt(Len, Len == 32 ? 0x3F : 0x7F), dl,
EltVT);
SmallVector<SDValue, 8> Cst3FV(NumElts, Cst3F);
SDValue M3F = DAG.getNode(ISD::BUILD_VECTOR, dl, VT, Cst3FV);
if (NeedsBitcast) {
Add = DAG.getNode(ISD::BITCAST, dl, BitcastVT, Add);
M3F = DAG.getNode(ISD::BITCAST, dl, BitcastVT, M3F);
// The high byte now contains the sum of the element bytes. Shift it right
// (if needed) to make it the low byte.
V = DAG.getNode(ISD::BITCAST, DL, VT, V);
if (Len > 8) {
Csts.assign(NumElts, DAG.getConstant(Len - 8, DL, EltVT));
V = DAG.getNode(ISD::SRL, DL, VT, V,
DAG.getNode(ISD::BUILD_VECTOR, DL, VT, Csts));
}
And = DAG.getNode(ISD::AND, dl, M3F.getValueType(), Add, M3F);
if (VT != And.getValueType())
And = DAG.getNode(ISD::BITCAST, dl, VT, And);
return V;
}
return And;
static SDValue LowerVectorCTPOP(SDValue Op, const X86Subtarget *Subtarget,
SelectionDAG &DAG) {
MVT VT = Op.getSimpleValueType();
// FIXME: Need to add AVX-512 support here!
assert((VT.is256BitVector() || VT.is128BitVector()) &&
"Unknown CTPOP type to handle");
SDLoc DL(Op.getNode());
SDValue Op0 = Op.getOperand(0);
if (VT.is256BitVector() && !Subtarget->hasInt256()) {
unsigned NumElems = VT.getVectorNumElements();
// Extract each 128-bit vector, compute pop count and concat the result.
SDValue LHS = Extract128BitVector(Op0, 0, DAG, DL);
SDValue RHS = Extract128BitVector(Op0, NumElems/2, DAG, DL);
return DAG.getNode(ISD::CONCAT_VECTORS, DL, VT,
LowerVectorCTPOPBitmath(LHS, DL, Subtarget, DAG),
LowerVectorCTPOPBitmath(RHS, DL, Subtarget, DAG));
}
return LowerVectorCTPOPBitmath(Op0, DL, Subtarget, DAG);
}
static SDValue LowerCTPOP(SDValue Op, const X86Subtarget *Subtarget,
SelectionDAG &DAG) {
assert(Op.getValueType().isVector() &&
"We only do custom lowering for vector population count.");
return LowerVectorCTPOP(Op, Subtarget, DAG);
}
static SDValue LowerLOAD_SUB(SDValue Op, SelectionDAG &DAG) {

File diff suppressed because it is too large.

@@ -99,14 +99,13 @@ define <4 x i64> @testv4i64(<4 x i64> %in) {
; AVX2-NEXT: vpaddq %ymm1, %ymm0, %ymm0
; AVX2-NEXT: vpbroadcastq {{.*}}(%rip), %ymm1
; AVX2-NEXT: vpand %ymm1, %ymm0, %ymm0
; AVX2-NEXT: vpsrlq $8, %ymm0, %ymm1
; AVX2-NEXT: vpaddq %ymm1, %ymm0, %ymm0
; AVX2-NEXT: vpsrlq $16, %ymm0, %ymm1
; AVX2-NEXT: vpaddq %ymm1, %ymm0, %ymm0
; AVX2-NEXT: vpsrlq $32, %ymm0, %ymm1
; AVX2-NEXT: vpaddq %ymm1, %ymm0, %ymm0
; AVX2-NEXT: vpbroadcastq {{.*}}(%rip), %ymm1
; AVX2-NEXT: vpand %ymm1, %ymm0, %ymm0
; AVX2-NEXT: vpsllq $32, %ymm0, %ymm1
; AVX2-NEXT: vpaddb %ymm1, %ymm0, %ymm0
; AVX2-NEXT: vpsllq $16, %ymm0, %ymm1
; AVX2-NEXT: vpaddb %ymm1, %ymm0, %ymm0
; AVX2-NEXT: vpsllq $8, %ymm0, %ymm1
; AVX2-NEXT: vpaddb %ymm1, %ymm0, %ymm0
; AVX2-NEXT: vpsrlq $56, %ymm0, %ymm0
; AVX2-NEXT: retq
%out = call <4 x i64> @llvm.ctpop.v4i64(<4 x i64> %in)
ret <4 x i64> %out
@@ -270,12 +269,11 @@ define <8 x i32> @testv8i32(<8 x i32> %in) {
; AVX2-NEXT: vpaddd %ymm1, %ymm0, %ymm0
; AVX2-NEXT: vpbroadcastd {{.*}}(%rip), %ymm1
; AVX2-NEXT: vpand %ymm1, %ymm0, %ymm0
; AVX2-NEXT: vpsrld $8, %ymm0, %ymm1
; AVX2-NEXT: vpaddd %ymm1, %ymm0, %ymm0
; AVX2-NEXT: vpsrld $16, %ymm0, %ymm1
; AVX2-NEXT: vpaddd %ymm1, %ymm0, %ymm0
; AVX2-NEXT: vpbroadcastd {{.*}}(%rip), %ymm1
; AVX2-NEXT: vpand %ymm1, %ymm0, %ymm0
; AVX2-NEXT: vpsllq $16, %ymm0, %ymm1
; AVX2-NEXT: vpaddb %ymm1, %ymm0, %ymm0
; AVX2-NEXT: vpsllq $8, %ymm0, %ymm1
; AVX2-NEXT: vpaddb %ymm1, %ymm0, %ymm0
; AVX2-NEXT: vpsrld $24, %ymm0, %ymm0
; AVX2-NEXT: retq
%out = call <8 x i32> @llvm.ctpop.v8i32(<8 x i32> %in)
ret <8 x i32> %out