[X86] Refactor the logic to select horizontal adds/subs to a helper function.

This patch moves part of the logic implemented by the target specific
combine rules added at r210477 to a separate helper function.
This should make easier to add more rules for matching AVX/AVX2 horizontal
adds/subs.

This patch also fixes a problem caused by a wrong check performed on indices
of extract_vector_elt dag nodes in input to the scalar adds/subs.

New tests have been added to verify that we correctly check indices of
extract_vector_elt dag nodes when selecting a horizontal operation.



git-svn-id: https://llvm.org/svn/llvm-project/llvm/trunk@210644 91177308-0d34-0410-b5e6-96231b3b80d8
This commit is contained in:
Andrea Di Biagio 2014-06-11 07:57:50 +00:00
parent e65c40320b
commit a069e64112
2 changed files with 200 additions and 107 deletions

View File

@ -6057,102 +6057,130 @@ X86TargetLowering::LowerBUILD_VECTORvXi1(SDValue Op, SelectionDAG &DAG) const {
return DAG.getNode(ISD::BITCAST, dl, VT, Select);
}
/// \brief Return true if \p N implements a horizontal binop and return the
/// operands for the horizontal binop into V0 and V1.
///
/// This is a helper function of PerformBUILD_VECTORCombine.
/// This function checks that the build_vector \p N in input implements a
/// horizontal operation. Parameter \p Opcode defines the kind of horizontal
/// operation to match.
/// For example, if \p Opcode is equal to ISD::ADD, then this function
/// checks if \p N implements a horizontal arithmetic add; if instead \p Opcode
/// is equal to ISD::SUB, then this function checks if this is a horizontal
/// arithmetic sub.
///
/// This function only analyzes elements of \p N whose indices are
/// in range [BaseIdx, LastIdx).
static bool isHorizontalBinOp(const BuildVectorSDNode *N, unsigned Opcode,
unsigned BaseIdx, unsigned LastIdx,
SDValue &V0, SDValue &V1) {
assert(BaseIdx * 2 <= LastIdx && "Invalid Indices in input!");
assert(N->getValueType(0).isVector() &&
N->getValueType(0).getVectorNumElements() >= LastIdx &&
"Invalid Vector in input!");
bool IsCommutable = (Opcode == ISD::ADD || Opcode == ISD::FADD);
bool CanFold = true;
unsigned ExpectedVExtractIdx = BaseIdx;
unsigned NumElts = LastIdx - BaseIdx;
// Check if N implements a horizontal binop.
for (unsigned i = 0, e = NumElts; i != e && CanFold; ++i) {
SDValue Op = N->getOperand(i + BaseIdx);
CanFold = Op->getOpcode() == Opcode && Op->hasOneUse();
if (!CanFold)
break;
SDValue Op0 = Op.getOperand(0);
SDValue Op1 = Op.getOperand(1);
// Try to match the following pattern:
// (BINOP (extract_vector_elt A, I), (extract_vector_elt A, I+1))
CanFold = (Op0.getOpcode() == ISD::EXTRACT_VECTOR_ELT &&
Op1.getOpcode() == ISD::EXTRACT_VECTOR_ELT &&
Op0.getOperand(0) == Op1.getOperand(0) &&
isa<ConstantSDNode>(Op0.getOperand(1)) &&
isa<ConstantSDNode>(Op1.getOperand(1)));
if (!CanFold)
break;
unsigned I0 = cast<ConstantSDNode>(Op0.getOperand(1))->getZExtValue();
unsigned I1 = cast<ConstantSDNode>(Op1.getOperand(1))->getZExtValue();
if (i == 0)
V0 = Op0.getOperand(0);
else if (i * 2 == NumElts) {
V1 = Op0.getOperand(0);
ExpectedVExtractIdx = BaseIdx;
}
SDValue Expected = (i * 2 < NumElts) ? V0 : V1;
if (I0 == ExpectedVExtractIdx)
CanFold = I1 == I0 + 1 && Op0.getOperand(0) == Expected;
else if (IsCommutable && I1 == ExpectedVExtractIdx) {
// Try to match the following dag sequence:
// (BINOP (extract_vector_elt A, I+1), (extract_vector_elt A, I))
CanFold = I0 == I1 + 1 && Op1.getOperand(0) == Expected;
} else
CanFold = false;
ExpectedVExtractIdx += 2;
}
return CanFold;
}
static SDValue PerformBUILD_VECTORCombine(SDNode *N, SelectionDAG &DAG,
const X86Subtarget *Subtarget) {
SDLoc DL(N);
EVT VT = N->getValueType(0);
unsigned NumElts = VT.getVectorNumElements();
BuildVectorSDNode *BV = cast<BuildVectorSDNode>(N);
SDValue InVec0, InVec1;
// Try to match a horizontal ADD or SUB.
if (((VT == MVT::v4f32 || VT == MVT::v2f64) && Subtarget->hasSSE3()) ||
((VT == MVT::v4i32 || VT == MVT::v8i16) && Subtarget->hasSSSE3()) ||
((VT == MVT::v8f32 || VT == MVT::v4f64 || VT == MVT::v8i32 ||
VT == MVT::v16i16) && Subtarget->hasAVX())) {
unsigned NumOperands = N->getNumOperands();
unsigned Opcode = N->getOperand(0)->getOpcode();
bool isCommutable = false;
bool CanFold = false;
switch (Opcode) {
default : break;
case ISD::ADD :
case ISD::FADD :
isCommutable = true;
// FALL-THROUGH
case ISD::SUB :
case ISD::FSUB :
CanFold = true;
}
// Try to match horizontal ADD/SUB.
if ((VT == MVT::v4f32 || VT == MVT::v2f64) && Subtarget->hasSSE3()) {
// Try to match an SSE3 float HADD/HSUB.
if (isHorizontalBinOp(BV, ISD::FADD, 0, NumElts, InVec0, InVec1))
return DAG.getNode(X86ISD::FHADD, DL, VT, InVec0, InVec1);
if (isHorizontalBinOp(BV, ISD::FSUB, 0, NumElts, InVec0, InVec1))
return DAG.getNode(X86ISD::FHSUB, DL, VT, InVec0, InVec1);
} else if ((VT == MVT::v4i32 || VT == MVT::v8i16) && Subtarget->hasSSSE3()) {
// Try to match an SSSE3 integer HADD/HSUB.
if (isHorizontalBinOp(BV, ISD::ADD, 0, NumElts, InVec0, InVec1))
return DAG.getNode(X86ISD::HADD, DL, VT, InVec0, InVec1);
if (isHorizontalBinOp(BV, ISD::SUB, 0, NumElts, InVec0, InVec1))
return DAG.getNode(X86ISD::HSUB, DL, VT, InVec0, InVec1);
}
// Verify that operands have the same opcode; also, the opcode can only
// be either of: ADD, FADD, SUB, FSUB.
SDValue InVec0, InVec1;
for (unsigned i = 0, e = NumOperands; i != e && CanFold; ++i) {
SDValue Op = N->getOperand(i);
CanFold = Op->getOpcode() == Opcode && Op->hasOneUse();
if ((VT == MVT::v8f32 || VT == MVT::v4f64 || VT == MVT::v8i32 ||
VT == MVT::v16i16) && Subtarget->hasAVX()) {
unsigned X86Opcode;
if (isHorizontalBinOp(BV, ISD::ADD, 0, NumElts, InVec0, InVec1))
X86Opcode = X86ISD::HADD;
else if (isHorizontalBinOp(BV, ISD::SUB, 0, NumElts, InVec0, InVec1))
X86Opcode = X86ISD::HSUB;
else if (isHorizontalBinOp(BV, ISD::FADD, 0, NumElts, InVec0, InVec1))
X86Opcode = X86ISD::FHADD;
else if (isHorizontalBinOp(BV, ISD::FSUB, 0, NumElts, InVec0, InVec1))
X86Opcode = X86ISD::FHSUB;
else
return SDValue();
if (!CanFold)
break;
// Convert this build_vector into two horizontal add/sub followed by
// a concat vector.
SDValue InVec0_LO = Extract128BitVector(InVec0, 0, DAG, DL);
SDValue InVec0_HI = Extract128BitVector(InVec0, NumElts/2, DAG, DL);
SDValue InVec1_LO = Extract128BitVector(InVec1, 0, DAG, DL);
SDValue InVec1_HI = Extract128BitVector(InVec1, NumElts/2, DAG, DL);
EVT NewVT = InVec0_LO.getValueType();
SDValue Op0 = Op.getOperand(0);
SDValue Op1 = Op.getOperand(1);
// Try to match the following pattern:
// (BINOP (extract_vector_elt A, I), (extract_vector_elt A, I+1))
CanFold = (Op0.getOpcode() == ISD::EXTRACT_VECTOR_ELT &&
Op1.getOpcode() == ISD::EXTRACT_VECTOR_ELT &&
Op0.getOperand(0) == Op1.getOperand(0) &&
isa<ConstantSDNode>(Op0.getOperand(1)) &&
isa<ConstantSDNode>(Op1.getOperand(1)));
if (!CanFold)
break;
unsigned I0 = cast<ConstantSDNode>(Op0.getOperand(1))->getZExtValue();
unsigned I1 = cast<ConstantSDNode>(Op1.getOperand(1))->getZExtValue();
unsigned ExpectedIndex = (i * 2) % NumOperands;
if (i == 0)
InVec0 = Op0.getOperand(0);
else if (i * 2 == NumOperands)
InVec1 = Op0.getOperand(0);
SDValue Expected = (i * 2 < NumOperands) ? InVec0 : InVec1;
if (I0 == ExpectedIndex)
CanFold = I1 == I0 + 1 && Op0.getOperand(0) == Expected;
else if (isCommutable && I1 == ExpectedIndex) {
// Try to see if we can match the following dag sequence:
// (BINOP (extract_vector_elt A, I+1), (extract_vector_elt A, I))
CanFold = I0 == I1 + 1 && Op1.getOperand(0) == Expected;
}
}
if (CanFold) {
unsigned NewOpcode;
switch (Opcode) {
default : llvm_unreachable("Unexpected opcode found!");
case ISD::ADD : NewOpcode = X86ISD::HADD; break;
case ISD::FADD : NewOpcode = X86ISD::FHADD; break;
case ISD::SUB : NewOpcode = X86ISD::HSUB; break;
case ISD::FSUB : NewOpcode = X86ISD::FHSUB; break;
}
if (VT.is256BitVector()) {
SDLoc dl(N);
// Convert this sequence into two horizontal add/sub followed
// by a concat vector.
SDValue InVec0_LO = Extract128BitVector(InVec0, 0, DAG, dl);
SDValue InVec0_HI =
Extract128BitVector(InVec0, NumOperands/2, DAG, dl);
SDValue InVec1_LO = Extract128BitVector(InVec1, 0, DAG, dl);
SDValue InVec1_HI =
Extract128BitVector(InVec1, NumOperands/2, DAG, dl);
EVT NewVT = InVec0_LO.getValueType();
SDValue LO = DAG.getNode(NewOpcode, dl, NewVT, InVec0_LO, InVec0_HI);
SDValue HI = DAG.getNode(NewOpcode, dl, NewVT, InVec1_LO, InVec1_HI);
return DAG.getNode(ISD::CONCAT_VECTORS, dl, VT, LO, HI);
}
return DAG.getNode(NewOpcode, SDLoc(N), VT, InVec0, InVec1);
}
SDValue LO = DAG.getNode(X86Opcode, DL, NewVT, InVec0_LO, InVec0_HI);
SDValue HI = DAG.getNode(X86Opcode, DL, NewVT, InVec1_LO, InVec1_HI);
return DAG.getNode(ISD::CONCAT_VECTORS, DL, VT, LO, HI);
}
return SDValue();

View File

@ -86,12 +86,12 @@ define <4 x float> @hsub_ps_test2(<4 x float> %A, <4 x float> %B) {
%vecext3 = extractelement <4 x float> %A, i32 1
%sub4 = fsub float %vecext2, %vecext3
%vecinit5 = insertelement <4 x float> %vecinit, float %sub4, i32 0
%vecext6 = extractelement <4 x float> %B, i32 3
%vecext7 = extractelement <4 x float> %B, i32 2
%vecext6 = extractelement <4 x float> %B, i32 2
%vecext7 = extractelement <4 x float> %B, i32 3
%sub8 = fsub float %vecext6, %vecext7
%vecinit9 = insertelement <4 x float> %vecinit5, float %sub8, i32 3
%vecext10 = extractelement <4 x float> %B, i32 1
%vecext11 = extractelement <4 x float> %B, i32 0
%vecext10 = extractelement <4 x float> %B, i32 0
%vecext11 = extractelement <4 x float> %B, i32 1
%sub12 = fsub float %vecext10, %vecext11
%vecinit13 = insertelement <4 x float> %vecinit9, float %sub12, i32 2
ret <4 x float> %vecinit13
@ -137,12 +137,12 @@ define <4 x i32> @phadd_d_test2(<4 x i32> %A, <4 x i32> %B) {
%vecext3 = extractelement <4 x i32> %A, i32 1
%add4 = add i32 %vecext2, %vecext3
%vecinit5 = insertelement <4 x i32> %vecinit, i32 %add4, i32 0
%vecext6 = extractelement <4 x i32> %B, i32 2
%vecext7 = extractelement <4 x i32> %B, i32 3
%vecext6 = extractelement <4 x i32> %B, i32 3
%vecext7 = extractelement <4 x i32> %B, i32 2
%add8 = add i32 %vecext6, %vecext7
%vecinit9 = insertelement <4 x i32> %vecinit5, i32 %add8, i32 3
%vecext10 = extractelement <4 x i32> %B, i32 0
%vecext11 = extractelement <4 x i32> %B, i32 1
%vecext10 = extractelement <4 x i32> %B, i32 1
%vecext11 = extractelement <4 x i32> %B, i32 0
%add12 = add i32 %vecext10, %vecext11
%vecinit13 = insertelement <4 x i32> %vecinit9, i32 %add12, i32 2
ret <4 x i32> %vecinit13
@ -191,12 +191,12 @@ define <4 x i32> @phsub_d_test2(<4 x i32> %A, <4 x i32> %B) {
%vecext3 = extractelement <4 x i32> %A, i32 1
%sub4 = sub i32 %vecext2, %vecext3
%vecinit5 = insertelement <4 x i32> %vecinit, i32 %sub4, i32 0
%vecext6 = extractelement <4 x i32> %B, i32 3
%vecext7 = extractelement <4 x i32> %B, i32 2
%vecext6 = extractelement <4 x i32> %B, i32 2
%vecext7 = extractelement <4 x i32> %B, i32 3
%sub8 = sub i32 %vecext6, %vecext7
%vecinit9 = insertelement <4 x i32> %vecinit5, i32 %sub8, i32 3
%vecext10 = extractelement <4 x i32> %B, i32 1
%vecext11 = extractelement <4 x i32> %B, i32 0
%vecext10 = extractelement <4 x i32> %B, i32 0
%vecext11 = extractelement <4 x i32> %B, i32 1
%sub12 = sub i32 %vecext10, %vecext11
%vecinit13 = insertelement <4 x i32> %vecinit9, i32 %sub12, i32 2
ret <4 x i32> %vecinit13
@ -258,14 +258,14 @@ define <2 x double> @hsub_pd_test1(<2 x double> %A, <2 x double> %B) {
define <2 x double> @hsub_pd_test2(<2 x double> %A, <2 x double> %B) {
%vecext = extractelement <2 x double> %A, i32 1
%vecext1 = extractelement <2 x double> %A, i32 0
%vecext = extractelement <2 x double> %B, i32 0
%vecext1 = extractelement <2 x double> %B, i32 1
%sub = fsub double %vecext, %vecext1
%vecinit = insertelement <2 x double> undef, double %sub, i32 0
%vecext2 = extractelement <2 x double> %B, i32 1
%vecext3 = extractelement <2 x double> %B, i32 0
%vecinit = insertelement <2 x double> undef, double %sub, i32 1
%vecext2 = extractelement <2 x double> %A, i32 0
%vecext3 = extractelement <2 x double> %A, i32 1
%sub2 = fsub double %vecext2, %vecext3
%vecinit2 = insertelement <2 x double> %vecinit, double %sub2, i32 1
%vecinit2 = insertelement <2 x double> %vecinit, double %sub2, i32 0
ret <2 x double> %vecinit2
}
; CHECK-LABEL: hsub_pd_test2
@ -458,3 +458,68 @@ define <16 x i16> @avx2_vphadd_w_test(<16 x i16> %a, <16 x i16> %b) {
; CHECK: ret
; Verify that we don't select horizontal subs in the following functions.
define <4 x i32> @not_a_hsub_1(<4 x i32> %A, <4 x i32> %B) {
%vecext = extractelement <4 x i32> %A, i32 0
%vecext1 = extractelement <4 x i32> %A, i32 1
%sub = sub i32 %vecext, %vecext1
%vecinit = insertelement <4 x i32> undef, i32 %sub, i32 0
%vecext2 = extractelement <4 x i32> %A, i32 2
%vecext3 = extractelement <4 x i32> %A, i32 3
%sub4 = sub i32 %vecext2, %vecext3
%vecinit5 = insertelement <4 x i32> %vecinit, i32 %sub4, i32 1
%vecext6 = extractelement <4 x i32> %B, i32 1
%vecext7 = extractelement <4 x i32> %B, i32 0
%sub8 = sub i32 %vecext6, %vecext7
%vecinit9 = insertelement <4 x i32> %vecinit5, i32 %sub8, i32 2
%vecext10 = extractelement <4 x i32> %B, i32 3
%vecext11 = extractelement <4 x i32> %B, i32 2
%sub12 = sub i32 %vecext10, %vecext11
%vecinit13 = insertelement <4 x i32> %vecinit9, i32 %sub12, i32 3
ret <4 x i32> %vecinit13
}
; CHECK-LABEL: not_a_hsub_1
; CHECK-NOT: phsubd
; CHECK: ret
define <4 x float> @not_a_hsub_2(<4 x float> %A, <4 x float> %B) {
%vecext = extractelement <4 x float> %A, i32 2
%vecext1 = extractelement <4 x float> %A, i32 3
%sub = fsub float %vecext, %vecext1
%vecinit = insertelement <4 x float> undef, float %sub, i32 1
%vecext2 = extractelement <4 x float> %A, i32 0
%vecext3 = extractelement <4 x float> %A, i32 1
%sub4 = fsub float %vecext2, %vecext3
%vecinit5 = insertelement <4 x float> %vecinit, float %sub4, i32 0
%vecext6 = extractelement <4 x float> %B, i32 3
%vecext7 = extractelement <4 x float> %B, i32 2
%sub8 = fsub float %vecext6, %vecext7
%vecinit9 = insertelement <4 x float> %vecinit5, float %sub8, i32 3
%vecext10 = extractelement <4 x float> %B, i32 0
%vecext11 = extractelement <4 x float> %B, i32 1
%sub12 = fsub float %vecext10, %vecext11
%vecinit13 = insertelement <4 x float> %vecinit9, float %sub12, i32 2
ret <4 x float> %vecinit13
}
; CHECK-LABEL: not_a_hsub_2
; CHECK-NOT: hsubps
; CHECK: ret
define <2 x double> @not_a_hsub_3(<2 x double> %A, <2 x double> %B) {
%vecext = extractelement <2 x double> %B, i32 0
%vecext1 = extractelement <2 x double> %B, i32 1
%sub = fsub double %vecext, %vecext1
%vecinit = insertelement <2 x double> undef, double %sub, i32 1
%vecext2 = extractelement <2 x double> %A, i32 1
%vecext3 = extractelement <2 x double> %A, i32 0
%sub2 = fsub double %vecext2, %vecext3
%vecinit2 = insertelement <2 x double> %vecinit, double %sub2, i32 0
ret <2 x double> %vecinit2
}
; CHECK-LABEL: not_a_hsub_3
; CHECK-NOT: hsubpd
; CHECK: ret