[x86] Split out the horizontal byte sum lowering component of the LUT

lowering into a helper function.

NFC.

git-svn-id: https://llvm.org/svn/llvm-project/llvm/trunk@238650 91177308-0d34-0410-b5e6-96231b3b80d8
This commit is contained in:
Chandler Carruth 2015-05-30 09:46:16 +00:00
parent be43b88fae
commit da8bb20158

View File

@ -17290,11 +17290,124 @@ static SDValue LowerBITCAST(SDValue Op, const X86Subtarget *Subtarget,
return SDValue();
}
/// Horizontally add groups of adjacent bytes of V, yielding a vector of type
/// VT.
///
/// V must be a vector of i8 and VT an integer vector type of the same total
/// bit width whose elements are wider than i8. Each element of the result is
/// the sum of the bytes of V that occupy the same position as that element, so
/// the element width of VT selects how many bytes are folded together.
static SDValue LowerHorizontalByteSum(SDValue V, MVT VT,
                                      const X86Subtarget *Subtarget,
                                      SelectionDAG &DAG) {
  SDLoc Loc(V);
  MVT BytesVT = V.getSimpleValueType();
  MVT WideEltVT = VT.getVectorElementType();
  int NumWideElts = VT.getVectorNumElements();
  assert(BytesVT.getVectorElementType() == MVT::i8 &&
         "Expected value to have byte element type.");
  assert(WideEltVT != MVT::i8 &&
         "Horizontal byte sum only makes sense for wider elements!");
  unsigned TotalBits = VT.getSizeInBits();
  assert(BytesVT.getSizeInBits() == TotalBits && "Cannot change vector size!");

  // i64 elements: PSADBW against an all-zero vector horizontally adds the
  // bytes and leaves the result in i64 chunks, which is exactly the layout
  // wanted for v2i64 and v4i64.
  if (WideEltVT == MVT::i64) {
    SDValue ZeroBytes = getZeroVector(BytesVT, Subtarget, DAG, Loc);
    SDValue Sums = DAG.getNode(X86ISD::PSADBW, Loc, BytesVT, V, ZeroBytes);
    return DAG.getBitcast(VT, Sums);
  }

  // i32 elements: interleave the low and the high half with zero i32s so that
  // each i32 of the input sits alone in an i64 chunk. PSADBW on each of the
  // two interleaved vectors then yields the per-i32 sums in i64 chunks, and
  // PACKUS narrows and concatenates the two results back into the i32 layout.
  if (WideEltVT == MVT::i32) {
    SDValue ZeroWide = getZeroVector(VT, Subtarget, DAG, Loc);
    SDValue LoHalf = DAG.getNode(X86ISD::UNPCKL, Loc, VT, V, ZeroWide);
    SDValue HiHalf = DAG.getNode(X86ISD::UNPCKH, Loc, VT, V, ZeroWide);

    // Sum the bytes of each half into i64 chunks.
    SDValue ZeroBytes = getZeroVector(BytesVT, Subtarget, DAG, Loc);
    SDValue LoBytes = DAG.getBitcast(BytesVT, LoHalf);
    SDValue LoSums = DAG.getNode(X86ISD::PSADBW, Loc, BytesVT, LoBytes,
                                 ZeroBytes);
    SDValue HiBytes = DAG.getBitcast(BytesVT, HiHalf);
    SDValue HiSums = DAG.getNode(X86ISD::PSADBW, Loc, BytesVT, HiBytes,
                                 ZeroBytes);

    // Pack the two partial results into a single vector again.
    MVT WordsVT = MVT::getVectorVT(MVT::i16, TotalBits / 16);
    SDValue LoWords = DAG.getBitcast(WordsVT, LoSums);
    SDValue HiWords = DAG.getBitcast(WordsVT, HiSums);
    SDValue Packed =
        DAG.getNode(X86ISD::PACKUS, Loc, BytesVT, LoWords, HiWords);
    return DAG.getBitcast(VT, Packed);
  }

  // i16 elements: separate the even and the odd bytes into two vectors, add
  // those, and zero-extend every i8 sum to i16, i.e.:
  //
  //  B -> per-i8 input values
  //  W -> per-i16 sums
  //
  //  Y = shuffle B, undef <0, 2, ...>
  //  Z = shuffle B, undef <1, 3, ...>
  //  W = zext <... x i8> to <... x i16> (Y + Z)
  //
  // The byte shuffle masks are shaped so that they match PSHUFB.
  assert(WideEltVT == MVT::i16 && "Unknown how to handle type");
  SDValue UndefBytes = DAG.getUNDEF(BytesVT);

  // PSHUFB cannot move bytes across 128-bit lanes, so shuffle and add within
  // each lane first and collapse the lanes afterwards. Per lane, the first
  // eight mask slots select the even (resp. odd) bytes of that lane and the
  // remaining eight slots are undefined.
  int LaneCount = TotalBits / 128;
  assert(TotalBits % 128 == 0 && "Must have 16-byte multiple vectors!");
  SmallVector<int, 32> EvenMask, OddMask;
  for (int Slot = 0; Slot < LaneCount * 16; ++Slot) {
    int LaneBase = (Slot / 16) * 16;
    int Pos = Slot % 16;
    if (Pos < 8) {
      EvenMask.push_back(LaneBase + Pos * 2);
      OddMask.push_back(LaneBase + (Pos * 2) + 1);
    } else {
      EvenMask.push_back(-1);
      OddMask.push_back(-1);
    }
  }
  SDValue EvenBytes =
      DAG.getVectorShuffle(BytesVT, Loc, V, UndefBytes, EvenMask);
  SDValue OddBytes = DAG.getVectorShuffle(BytesVT, Loc, V, UndefBytes, OddMask);
  SDValue PairSums = DAG.getNode(ISD::ADD, Loc, BytesVT, EvenBytes, OddBytes);

  // Collapse the lanes: viewed as i64 elements, the defined sums of lane i
  // live in element 2 * i. Move them to the front; the rest stays undefined.
  SmallVector<int, 4> CollapseMask;
  for (int Idx = 0; Idx < 2 * LaneCount; ++Idx)
    CollapseMask.push_back(Idx < LaneCount ? 2 * Idx : -1);
  int QuadCount = TotalBits / 64;
  MVT QuadsVT = MVT::getVectorVT(MVT::i64, QuadCount);
  SDValue AsQuads = DAG.getBitcast(QuadsVT, PairSums);
  SDValue Collapsed = DAG.getVectorShuffle(QuadsVT, Loc, AsQuads,
                                           DAG.getUNDEF(QuadsVT), CollapseMask);
  SDValue CollapsedBytes = DAG.getBitcast(BytesVT, Collapsed);

  // Zero-extend the i8 sums to i16 by interleaving them with the first byte
  // of an all-zero second shuffle operand (index 2 * NumWideElts, which is the
  // first index past the bytes of the first operand).
  const int ZeroByteIdx = 2 * NumWideElts;
  SmallVector<int, 16> ZExtMask;
  for (int Elt = 0; Elt < NumWideElts; ++Elt) {
    ZExtMask.push_back(Elt);
    ZExtMask.push_back(ZeroByteIdx);
  }
  SDValue ZeroBytes = getZeroVector(BytesVT, Subtarget, DAG, Loc);
  SDValue Extended =
      DAG.getVectorShuffle(BytesVT, Loc, CollapsedBytes, ZeroBytes, ZExtMask);
  return DAG.getBitcast(VT, Extended);
}
static SDValue LowerVectorCTPOPInRegLUT(SDValue Op, SDLoc DL,
const X86Subtarget *Subtarget,
SelectionDAG &DAG) {
EVT VT = Op.getValueType();
MVT EltVT = VT.getVectorElementType().getSimpleVT();
const X86Subtarget *Subtarget,
SelectionDAG &DAG) {
MVT VT = Op.getSimpleValueType();
MVT EltVT = VT.getVectorElementType();
unsigned VecSize = VT.getSizeInBits();
// Implement a lookup table in register by using an algorithm based on:
@ -17347,98 +17460,7 @@ static SDValue LowerVectorCTPOPInRegLUT(SDValue Op, SDLoc DL,
if (EltVT == MVT::i8)
return PopCnt;
// PSADBW instruction horizontally add all bytes and leave the result in i64
// chunks, thus directly computes the pop count for v2i64 and v4i64.
if (EltVT == MVT::i64) {
SDValue Zeros = getZeroVector(ByteVecVT, Subtarget, DAG, DL);
PopCnt = DAG.getNode(X86ISD::PSADBW, DL, ByteVecVT, PopCnt, Zeros);
return DAG.getBitcast(VT, PopCnt);
}
int NumI64Elts = VecSize / 64;
MVT VecI64VT = MVT::getVectorVT(MVT::i64, NumI64Elts);
if (EltVT == MVT::i32) {
// We unpack the low half and high half into i32s interleaved with zeros so
// that we can use PSADBW to horizontally sum them. The most useful part of
// this is that it lines up the results of two PSADBW instructions to be
// two v2i64 vectors which concatenated are the 4 population counts. We can
// then use PACKUSWB to shrink and concatenate them into a v4i32 again.
SDValue Zeros = getZeroVector(VT, Subtarget, DAG, DL);
SDValue Low = DAG.getNode(X86ISD::UNPCKL, DL, VT, PopCnt, Zeros);
SDValue High = DAG.getNode(X86ISD::UNPCKH, DL, VT, PopCnt, Zeros);
// Do the horizontal sums into two v2i64s.
Zeros = getZeroVector(ByteVecVT, Subtarget, DAG, DL);
Low = DAG.getNode(X86ISD::PSADBW, DL, ByteVecVT,
DAG.getBitcast(ByteVecVT, Low), Zeros);
High = DAG.getNode(X86ISD::PSADBW, DL, ByteVecVT,
DAG.getBitcast(ByteVecVT, High), Zeros);
// Merge them together.
MVT ShortVecVT = MVT::getVectorVT(MVT::i16, VecSize / 16);
PopCnt = DAG.getNode(X86ISD::PACKUS, DL, ByteVecVT,
DAG.getBitcast(ShortVecVT, Low),
DAG.getBitcast(ShortVecVT, High));
return DAG.getBitcast(VT, PopCnt);
}
// To obtain pop count for each i16 element, shuffle the byte pop count to get
// even and odd elements into distinct vectors, add them and zero-extend each
// i8 elemento into i16, i.e.:
//
// B -> pop count per i8
// W -> pop count per i16
//
// Y = shuffle B, undef <0, 2, ...>
// Z = shuffle B, undef <1, 3, ...>
// W = zext <... x i8> to <... x i16> (Y + Z)
//
// Use a byte shuffle mask that matches PSHUFB.
//
assert(EltVT == MVT::i16 && "Unknown how to handle type");
SDValue Undef = DAG.getUNDEF(ByteVecVT);
SmallVector<int, 32> MaskA, MaskB;
// We can't use PSHUFB across lanes, so do the shuffle and sum inside each
// 128-bit lane, and then collapse the result.
int NumLanes = NumByteElts / 16;
assert(NumByteElts % 16 == 0 && "Must have 16-byte multiple vectors!");
for (int i = 0; i < NumLanes; ++i) {
for (int j = 0; j < 8; ++j) {
MaskA.push_back(i * 16 + j * 2);
MaskB.push_back(i * 16 + (j * 2) + 1);
}
MaskA.append((size_t)8, -1);
MaskB.append((size_t)8, -1);
}
SDValue ShuffA = DAG.getVectorShuffle(ByteVecVT, DL, PopCnt, Undef, MaskA);
SDValue ShuffB = DAG.getVectorShuffle(ByteVecVT, DL, PopCnt, Undef, MaskB);
PopCnt = DAG.getNode(ISD::ADD, DL, ByteVecVT, ShuffA, ShuffB);
SmallVector<int, 4> Mask;
for (int i = 0; i < NumLanes; ++i)
Mask.push_back(2 * i);
Mask.append((size_t)NumLanes, -1);
PopCnt = DAG.getBitcast(VecI64VT, PopCnt);
PopCnt =
DAG.getVectorShuffle(VecI64VT, DL, PopCnt, DAG.getUNDEF(VecI64VT), Mask);
PopCnt = DAG.getBitcast(ByteVecVT, PopCnt);
// Zero extend i8s into i16 elts
SmallVector<int, 16> ZExtInRegMask;
for (int i = 0; i < NumByteElts / 2; ++i) {
ZExtInRegMask.push_back(i);
ZExtInRegMask.push_back(NumByteElts);
}
return DAG.getBitcast(
VT, DAG.getVectorShuffle(ByteVecVT, DL, PopCnt,
getZeroVector(ByteVecVT, Subtarget, DAG, DL),
ZExtInRegMask));
return LowerHorizontalByteSum(PopCnt, VT, Subtarget, DAG);
}
static SDValue LowerVectorCTPOPBitmath(SDValue Op, SDLoc DL,