[x86] Restore the bitcasts I removed when refactoring this to avoid
shifting vectors of bytes, as x86 doesn't have direct support for that.

This removes a bunch of redundant masking in the generated code for SSE2
and SSE3.

In order to avoid the really significant code size growth this would
have triggered, I also factored the completely repetitive logic for
shifting and masking into two lambdas which in turn makes all of this
much easier to read IMO.

git-svn-id: https://llvm.org/svn/llvm-project/llvm/trunk@238637 91177308-0d34-0410-b5e6-96231b3b80d8
Chandler Carruth 2015-05-30 04:05:11 +00:00
parent 828f5b807c
commit d8018eeac9
2 changed files with 43 additions and 53 deletions
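
For context, the routine being refactored vectorizes the classic parallel
bit-count from the Stanford bit-hacks page cited in the code's own comments.
A minimal scalar C++ sketch of the same dataflow (the function name is mine,
for illustration only; the two shift/add lines and the final right shift
correspond to the Len > 8 horizontal-sum code in the diff below):

    #include <cstdint>

    // Byte-wise popcount via the 0x55/0x33/0x0F mask ladder, then a
    // horizontal byte sum for a 32-bit element, mirroring the lowering.
    uint32_t popcount32(uint32_t v) {
      v = v - ((v >> 1) & 0x55555555);                // pairwise bit sums
      v = (v & 0x33333333) + ((v >> 2) & 0x33333333); // nibble sums
      v = (v + (v >> 4)) & 0x0F0F0F0F;                // per-byte counts
      v = v + (v << 16);                              // the i > 8 shift/add loop
      v = v + (v << 8);
      return v >> 24;                                 // final SRL by Len - 8
    }

The patch does not change this algorithm; it only changes which vector type
the SRL steps are issued in.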


@@ -17479,7 +17479,6 @@ static SDValue LowerVectorCTPOPBitmath(SDValue Op, SDLoc DL,
          "Only 128-bit vector bitmath lowering supported.");
   int VecSize = VT.getSizeInBits();
-  int NumElts = VT.getVectorNumElements();
   MVT EltVT = VT.getVectorElementType();
   int Len = EltVT.getSizeInBits();
@@ -17490,48 +17489,52 @@ static SDValue LowerVectorCTPOPBitmath(SDValue Op, SDLoc DL,
   // this when we don't have SSSE3 which allows a LUT-based lowering that is
   // much faster, even faster than using native popcnt instructions.
-  SDValue Cst55 = DAG.getConstant(APInt::getSplat(Len, APInt(8, 0x55)), DL,
-                                  EltVT);
-  SDValue Cst33 = DAG.getConstant(APInt::getSplat(Len, APInt(8, 0x33)), DL,
-                                  EltVT);
-  SDValue Cst0F = DAG.getConstant(APInt::getSplat(Len, APInt(8, 0x0F)), DL,
-                                  EltVT);
+  auto GetShift = [&](unsigned OpCode, SDValue V, int Shifter) {
+    MVT VT = V.getSimpleValueType();
+    SmallVector<SDValue, 32> Shifters(
+        VT.getVectorNumElements(),
+        DAG.getConstant(Shifter, DL, VT.getVectorElementType()));
+    return DAG.getNode(OpCode, DL, VT, V,
+                       DAG.getNode(ISD::BUILD_VECTOR, DL, VT, Shifters));
+  };
+  auto GetMask = [&](SDValue V, APInt Mask) {
+    MVT VT = V.getSimpleValueType();
+    SmallVector<SDValue, 32> Masks(
+        VT.getVectorNumElements(),
+        DAG.getConstant(Mask, DL, VT.getVectorElementType()));
+    return DAG.getNode(ISD::AND, DL, VT, V,
+                       DAG.getNode(ISD::BUILD_VECTOR, DL, VT, Masks));
+  };
+  // We don't want to incur the implicit masks required to SRL vNi8 vectors on
+  // x86, so set the SRL type to have elements at least i16 wide. This is
+  // correct because all of our SRLs are followed immediately by a mask anyways
+  // that handles any bits that sneak into the high bits of the byte elements.
+  MVT SrlVT = Len > 8 ? VT : MVT::getVectorVT(MVT::i16, VecSize / 16);
   SDValue V = Op;
   // v = v - ((v >> 1) & 0x55555555...)
-  SmallVector<SDValue, 8> Ones(NumElts, DAG.getConstant(1, DL, EltVT));
-  SDValue OnesV = DAG.getNode(ISD::BUILD_VECTOR, DL, VT, Ones);
-  SDValue Srl = DAG.getNode(ISD::SRL, DL, VT, V, OnesV);
-  SmallVector<SDValue, 8> Mask55(NumElts, Cst55);
-  SDValue M55 = DAG.getNode(ISD::BUILD_VECTOR, DL, VT, Mask55);
-  SDValue And = DAG.getNode(ISD::AND, DL, Srl.getValueType(), Srl, M55);
+  SDValue Srl = DAG.getNode(
+      ISD::BITCAST, DL, VT,
+      GetShift(ISD::SRL, DAG.getNode(ISD::BITCAST, DL, SrlVT, V), 1));
+  SDValue And = GetMask(Srl, APInt::getSplat(Len, APInt(8, 0x55)));
   V = DAG.getNode(ISD::SUB, DL, VT, V, And);
   // v = (v & 0x33333333...) + ((v >> 2) & 0x33333333...)
-  SmallVector<SDValue, 8> Mask33(NumElts, Cst33);
-  SDValue M33 = DAG.getNode(ISD::BUILD_VECTOR, DL, VT, Mask33);
-  SDValue AndLHS = DAG.getNode(ISD::AND, DL, M33.getValueType(), V, M33);
-  SmallVector<SDValue, 8> Twos(NumElts, DAG.getConstant(2, DL, EltVT));
-  SDValue TwosV = DAG.getNode(ISD::BUILD_VECTOR, DL, VT, Twos);
-  Srl = DAG.getNode(ISD::SRL, DL, VT, V, TwosV);
-  SDValue AndRHS = DAG.getNode(ISD::AND, DL, M33.getValueType(), Srl, M33);
+  SDValue AndLHS = GetMask(V, APInt::getSplat(Len, APInt(8, 0x33)));
+  Srl = DAG.getNode(
+      ISD::BITCAST, DL, VT,
+      GetShift(ISD::SRL, DAG.getNode(ISD::BITCAST, DL, SrlVT, V), 2));
+  SDValue AndRHS = GetMask(Srl, APInt::getSplat(Len, APInt(8, 0x33)));
   V = DAG.getNode(ISD::ADD, DL, VT, AndLHS, AndRHS);
   // v = (v + (v >> 4)) & 0x0F0F0F0F...
-  SmallVector<SDValue, 8> Fours(NumElts, DAG.getConstant(4, DL, EltVT));
-  SDValue FoursV = DAG.getNode(ISD::BUILD_VECTOR, DL, VT, Fours);
-  Srl = DAG.getNode(ISD::SRL, DL, VT, V, FoursV);
+  Srl = DAG.getNode(
+      ISD::BITCAST, DL, VT,
+      GetShift(ISD::SRL, DAG.getNode(ISD::BITCAST, DL, SrlVT, V), 4));
   SDValue Add = DAG.getNode(ISD::ADD, DL, VT, V, Srl);
-  SmallVector<SDValue, 8> Mask0F(NumElts, Cst0F);
-  SDValue M0F = DAG.getNode(ISD::BUILD_VECTOR, DL, VT, Mask0F);
-  V = DAG.getNode(ISD::AND, DL, M0F.getValueType(), Add, M0F);
+  V = GetMask(Add, APInt::getSplat(Len, APInt(8, 0x0F)));
   // At this point, V contains the byte-wise population count, and we are
   // merely doing a horizontal sum if necessary to get the wider element
@@ -17543,26 +17546,21 @@ static SDValue LowerVectorCTPOPBitmath(SDValue Op, SDLoc DL,
   MVT ByteVT = MVT::getVectorVT(MVT::i8, VecSize / 8);
   MVT ShiftVT = MVT::getVectorVT(MVT::i64, VecSize / 64);
   V = DAG.getNode(ISD::BITCAST, DL, ByteVT, V);
-  SmallVector<SDValue, 8> Csts;
   assert(Len <= 64 && "We don't support element sizes of more than 64 bits!");
   assert(isPowerOf2_32(Len) && "Only power of two element sizes supported!");
   for (int i = Len; i > 8; i /= 2) {
-    Csts.assign(VecSize / 64, DAG.getConstant(i / 2, DL, MVT::i64));
     SDValue Shl = DAG.getNode(
-        ISD::SHL, DL, ShiftVT, DAG.getNode(ISD::BITCAST, DL, ShiftVT, V),
-        DAG.getNode(ISD::BUILD_VECTOR, DL, ShiftVT, Csts));
-    V = DAG.getNode(ISD::ADD, DL, ByteVT, V,
-                    DAG.getNode(ISD::BITCAST, DL, ByteVT, Shl));
+        ISD::BITCAST, DL, ByteVT,
+        GetShift(ISD::SHL, DAG.getNode(ISD::BITCAST, DL, ShiftVT, V), i / 2));
+    V = DAG.getNode(ISD::ADD, DL, ByteVT, V, Shl);
   }
   // The high byte now contains the sum of the element bytes. Shift it right
   // (if needed) to make it the low byte.
   V = DAG.getNode(ISD::BITCAST, DL, VT, V);
-  if (Len > 8) {
-    Csts.assign(NumElts, DAG.getConstant(Len - 8, DL, EltVT));
-    V = DAG.getNode(ISD::SRL, DL, VT, V,
-                    DAG.getNode(ISD::BUILD_VECTOR, DL, VT, Csts));
-  }
+  if (Len > 8)
+    V = GetShift(ISD::SRL, V, Len - 8);
   return V;
 }
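
The SrlVT comment above deserves a concrete illustration: x86 has no
byte-element shifts, so the lowering shifts i16 elements and relies on the
explicit mask that immediately follows to clean up the bits that bleed
across byte lanes. A tiny sketch under that assumption (the helper name is
mine, not from the patch):

    #include <cstdint>

    // Two byte lanes packed in one 16-bit word: the word-wide shift leaks
    // the high byte's low bits into the low byte's high bits. As long as
    // the top n bits of `mask` are zero (true for 0x55/n=1, 0x33/n=2,
    // 0x0F/n=4), the following AND clears exactly the leaked positions.
    uint16_t srl_bytes_via_word(uint16_t w, unsigned n, uint8_t mask) {
      uint16_t shifted = w >> n;                          // lanes bleed at bit 8
      uint16_t lanemask = (uint16_t)((mask << 8) | mask); // splat mask per lane
      return shifted & lanemask;                          // discard the leak
    }

Before this patch, the vNi8 SRL itself emitted an implicit mask of this
form, and the algorithm then masked again, which is the redundancy the test
deltas below show disappearing.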


@@ -339,21 +339,17 @@ define <16 x i8> @testv16i8(<16 x i8> %in) {
 ; SSE2-NEXT:    movdqa %xmm0, %xmm1
 ; SSE2-NEXT:    psrlw $1, %xmm1
 ; SSE2-NEXT:    pand {{.*}}(%rip), %xmm1
-; SSE2-NEXT:    pand {{.*}}(%rip), %xmm1
 ; SSE2-NEXT:    psubb %xmm1, %xmm0
 ; SSE2-NEXT:    movdqa {{.*#+}} xmm1 = [51,51,51,51,51,51,51,51,51,51,51,51,51,51,51,51]
 ; SSE2-NEXT:    movdqa %xmm0, %xmm2
 ; SSE2-NEXT:    pand %xmm1, %xmm2
 ; SSE2-NEXT:    psrlw $2, %xmm0
-; SSE2-NEXT:    pand {{.*}}(%rip), %xmm0
 ; SSE2-NEXT:    pand %xmm1, %xmm0
 ; SSE2-NEXT:    paddb %xmm2, %xmm0
 ; SSE2-NEXT:    movdqa %xmm0, %xmm1
 ; SSE2-NEXT:    psrlw $4, %xmm1
-; SSE2-NEXT:    movdqa {{.*#+}} xmm2 = [15,15,15,15,15,15,15,15,15,15,15,15,15,15,15,15]
-; SSE2-NEXT:    pand %xmm2, %xmm1
 ; SSE2-NEXT:    paddb %xmm0, %xmm1
-; SSE2-NEXT:    pand %xmm2, %xmm1
+; SSE2-NEXT:    pand {{.*}}(%rip), %xmm1
 ; SSE2-NEXT:    movdqa %xmm1, %xmm0
 ; SSE2-NEXT:    retq
 ;
@@ -362,21 +358,17 @@ define <16 x i8> @testv16i8(<16 x i8> %in) {
 ; SSE3-NEXT:    movdqa %xmm0, %xmm1
 ; SSE3-NEXT:    psrlw $1, %xmm1
 ; SSE3-NEXT:    pand {{.*}}(%rip), %xmm1
-; SSE3-NEXT:    pand {{.*}}(%rip), %xmm1
 ; SSE3-NEXT:    psubb %xmm1, %xmm0
 ; SSE3-NEXT:    movdqa {{.*#+}} xmm1 = [51,51,51,51,51,51,51,51,51,51,51,51,51,51,51,51]
 ; SSE3-NEXT:    movdqa %xmm0, %xmm2
 ; SSE3-NEXT:    pand %xmm1, %xmm2
 ; SSE3-NEXT:    psrlw $2, %xmm0
-; SSE3-NEXT:    pand {{.*}}(%rip), %xmm0
 ; SSE3-NEXT:    pand %xmm1, %xmm0
 ; SSE3-NEXT:    paddb %xmm2, %xmm0
 ; SSE3-NEXT:    movdqa %xmm0, %xmm1
 ; SSE3-NEXT:    psrlw $4, %xmm1
-; SSE3-NEXT:    movdqa {{.*#+}} xmm2 = [15,15,15,15,15,15,15,15,15,15,15,15,15,15,15,15]
-; SSE3-NEXT:    pand %xmm2, %xmm1
 ; SSE3-NEXT:    paddb %xmm0, %xmm1
-; SSE3-NEXT:    pand %xmm2, %xmm1
+; SSE3-NEXT:    pand {{.*}}(%rip), %xmm1
 ; SSE3-NEXT:    movdqa %xmm1, %xmm0
 ; SSE3-NEXT:    retq
 ;
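
Read against the C++ change, the test deltas amount to one pand removed per
psrlw, plus dropping the register-cached 0x0F constant that the extra mask
made profitable. As a sketch, assuming nothing beyond SSE2 (the function
name is mine, not from the patch), the post-patch byte-popcount sequence
corresponds to intrinsics like:

    #include <emmintrin.h> // SSE2

    // Byte-wise popcount of 16 bytes, mirroring the post-patch codegen:
    // each psrlw is followed by exactly one pand, with no redundant
    // implicit mask from lowering a vNi8 shift.
    __m128i popcount_epi8(__m128i v) {
      __m128i m55 = _mm_set1_epi8(0x55);
      __m128i m33 = _mm_set1_epi8(0x33);
      __m128i m0f = _mm_set1_epi8(0x0F);
      // v = v - ((v >> 1) & 0x55): psrlw $1 + pand + psubb
      v = _mm_sub_epi8(v, _mm_and_si128(_mm_srli_epi16(v, 1), m55));
      // v = (v & 0x33) + ((v >> 2) & 0x33): pand + psrlw $2 + pand + paddb
      v = _mm_add_epi8(_mm_and_si128(v, m33),
                       _mm_and_si128(_mm_srli_epi16(v, 2), m33));
      // v = (v + (v >> 4)) & 0x0F: psrlw $4 + paddb + pand
      return _mm_and_si128(_mm_add_epi8(v, _mm_srli_epi16(v, 4)), m0f);
    }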