DAGCombiner: Simplify code a bit, make more transforms work with vectors.

git-svn-id: https://llvm.org/svn/llvm-project/llvm/trunk@207338 91177308-0d34-0410-b5e6-96231b3b80d8
This commit is contained in:
Benjamin Kramer 2014-04-26 23:09:49 +00:00
parent eadbda3320
commit eb3430cfbd
2 changed files with 72 additions and 58 deletions

View File

@ -644,8 +644,13 @@ static ConstantSDNode *isConstOrConstSplat(SDValue N) {
if (ConstantSDNode *CN = dyn_cast<ConstantSDNode>(N)) if (ConstantSDNode *CN = dyn_cast<ConstantSDNode>(N))
return CN; return CN;
if (BuildVectorSDNode *BV = dyn_cast<BuildVectorSDNode>(N)) if (BuildVectorSDNode *BV = dyn_cast<BuildVectorSDNode>(N)) {
return BV->getConstantSplatValue(); ConstantSDNode *CN = BV->getConstantSplatValue();
// BuildVectors can truncate their operands. Ignore that case here.
if (CN && CN->getValueType(0) == N.getValueType().getScalarType())
return CN;
}
return nullptr; return nullptr;
} }
@ -1957,8 +1962,8 @@ SDValue DAGCombiner::visitMUL(SDNode *N) {
SDValue DAGCombiner::visitSDIV(SDNode *N) { SDValue DAGCombiner::visitSDIV(SDNode *N) {
SDValue N0 = N->getOperand(0); SDValue N0 = N->getOperand(0);
SDValue N1 = N->getOperand(1); SDValue N1 = N->getOperand(1);
ConstantSDNode *N0C = dyn_cast<ConstantSDNode>(N0.getNode()); ConstantSDNode *N0C = isConstOrConstSplat(N0);
ConstantSDNode *N1C = dyn_cast<ConstantSDNode>(N1.getNode()); ConstantSDNode *N1C = isConstOrConstSplat(N1);
EVT VT = N->getValueType(0); EVT VT = N->getValueType(0);
// fold vector ops // fold vector ops
@ -1985,25 +1990,15 @@ SDValue DAGCombiner::visitSDIV(SDNode *N) {
N0, N1); N0, N1);
} }
const APInt *Divisor = nullptr;
if (N1C) {
Divisor = &N1C->getAPIntValue();
} else if (N1.getValueType().isVector() &&
N1->getOpcode() == ISD::BUILD_VECTOR) {
BuildVectorSDNode *BV = cast<BuildVectorSDNode>(N->getOperand(1));
if (ConstantSDNode *C = BV->getConstantSplatValue())
Divisor = &C->getAPIntValue();
}
// fold (sdiv X, pow2) -> simple ops after legalize // fold (sdiv X, pow2) -> simple ops after legalize
if (Divisor && !!*Divisor && if (N1C && !N1C->isNullValue() && (N1C->getAPIntValue().isPowerOf2() ||
(Divisor->isPowerOf2() || (-*Divisor).isPowerOf2())) { (-N1C->getAPIntValue()).isPowerOf2())) {
// If dividing by powers of two is cheap, then don't perform the following // If dividing by powers of two is cheap, then don't perform the following
// fold. // fold.
if (TLI.isPow2DivCheap()) if (TLI.isPow2DivCheap())
return SDValue(); return SDValue();
unsigned lg2 = Divisor->countTrailingZeros(); unsigned lg2 = N1C->getAPIntValue().countTrailingZeros();
// Splat the sign bit into the register // Splat the sign bit into the register
SDValue SGN = SDValue SGN =
@ -2025,7 +2020,7 @@ SDValue DAGCombiner::visitSDIV(SDNode *N) {
// If we're dividing by a positive value, we're done. Otherwise, we must // If we're dividing by a positive value, we're done. Otherwise, we must
// negate the result. // negate the result.
if (Divisor->isNonNegative()) if (N1C->getAPIntValue().isNonNegative())
return SRA; return SRA;
AddToWorkList(SRA.getNode()); AddToWorkList(SRA.getNode());
@ -2034,7 +2029,7 @@ SDValue DAGCombiner::visitSDIV(SDNode *N) {
// if integer divide is expensive and we satisfy the requirements, emit an // if integer divide is expensive and we satisfy the requirements, emit an
// alternate sequence. // alternate sequence.
if ((N1C || N1->getOpcode() == ISD::BUILD_VECTOR) && !TLI.isIntDivCheap()) { if (N1C && !TLI.isIntDivCheap()) {
SDValue Op = BuildSDIV(N); SDValue Op = BuildSDIV(N);
if (Op.getNode()) return Op; if (Op.getNode()) return Op;
} }
@ -2052,8 +2047,8 @@ SDValue DAGCombiner::visitSDIV(SDNode *N) {
SDValue DAGCombiner::visitUDIV(SDNode *N) { SDValue DAGCombiner::visitUDIV(SDNode *N) {
SDValue N0 = N->getOperand(0); SDValue N0 = N->getOperand(0);
SDValue N1 = N->getOperand(1); SDValue N1 = N->getOperand(1);
ConstantSDNode *N0C = dyn_cast<ConstantSDNode>(N0.getNode()); ConstantSDNode *N0C = isConstOrConstSplat(N0);
ConstantSDNode *N1C = dyn_cast<ConstantSDNode>(N1.getNode()); ConstantSDNode *N1C = isConstOrConstSplat(N1);
EVT VT = N->getValueType(0); EVT VT = N->getValueType(0);
// fold vector ops // fold vector ops
@ -2086,7 +2081,7 @@ SDValue DAGCombiner::visitUDIV(SDNode *N) {
} }
} }
// fold (udiv x, c) -> alternate // fold (udiv x, c) -> alternate
if ((N1C || N1->getOpcode() == ISD::BUILD_VECTOR) && !TLI.isIntDivCheap()) { if (N1C && !TLI.isIntDivCheap()) {
SDValue Op = BuildUDIV(N); SDValue Op = BuildUDIV(N);
if (Op.getNode()) return Op; if (Op.getNode()) return Op;
} }
@ -2104,8 +2099,8 @@ SDValue DAGCombiner::visitUDIV(SDNode *N) {
SDValue DAGCombiner::visitSREM(SDNode *N) { SDValue DAGCombiner::visitSREM(SDNode *N) {
SDValue N0 = N->getOperand(0); SDValue N0 = N->getOperand(0);
SDValue N1 = N->getOperand(1); SDValue N1 = N->getOperand(1);
ConstantSDNode *N0C = dyn_cast<ConstantSDNode>(N0); ConstantSDNode *N0C = isConstOrConstSplat(N0);
ConstantSDNode *N1C = dyn_cast<ConstantSDNode>(N1); ConstantSDNode *N1C = isConstOrConstSplat(N1);
EVT VT = N->getValueType(0); EVT VT = N->getValueType(0);
// fold (srem c1, c2) -> c1%c2 // fold (srem c1, c2) -> c1%c2
@ -2146,8 +2141,8 @@ SDValue DAGCombiner::visitSREM(SDNode *N) {
SDValue DAGCombiner::visitUREM(SDNode *N) { SDValue DAGCombiner::visitUREM(SDNode *N) {
SDValue N0 = N->getOperand(0); SDValue N0 = N->getOperand(0);
SDValue N1 = N->getOperand(1); SDValue N1 = N->getOperand(1);
ConstantSDNode *N0C = dyn_cast<ConstantSDNode>(N0); ConstantSDNode *N0C = isConstOrConstSplat(N0);
ConstantSDNode *N1C = dyn_cast<ConstantSDNode>(N1); ConstantSDNode *N1C = isConstOrConstSplat(N1);
EVT VT = N->getValueType(0); EVT VT = N->getValueType(0);
// fold (urem c1, c2) -> c1%c2 // fold (urem c1, c2) -> c1%c2
@ -11187,28 +11182,20 @@ SDValue DAGCombiner::SimplifySetCC(EVT VT, SDValue N0,
/// multiplying by a magic number. See: /// multiplying by a magic number. See:
/// <http://the.wall.riscom.net/books/proc/ppc/cwg/code2.html> /// <http://the.wall.riscom.net/books/proc/ppc/cwg/code2.html>
SDValue DAGCombiner::BuildSDIV(SDNode *N) { SDValue DAGCombiner::BuildSDIV(SDNode *N) {
const APInt *Divisor; ConstantSDNode *C = isConstOrConstSplat(N->getOperand(1));
if (N->getValueType(0).isVector()) { if (!C)
// Handle splat vectors. return SDValue();
BuildVectorSDNode *BV = cast<BuildVectorSDNode>(N->getOperand(1));
if (ConstantSDNode *C = BV->getConstantSplatValue())
Divisor = &C->getAPIntValue();
else
return SDValue();
} else {
Divisor = &cast<ConstantSDNode>(N->getOperand(1))->getAPIntValue();
}
// Avoid division by zero. // Avoid division by zero.
if (!*Divisor) if (!C->getAPIntValue())
return SDValue(); return SDValue();
std::vector<SDNode*> Built; std::vector<SDNode*> Built;
SDValue S = TLI.BuildSDIV(N, *Divisor, DAG, LegalOperations, &Built); SDValue S =
TLI.BuildSDIV(N, C->getAPIntValue(), DAG, LegalOperations, &Built);
for (std::vector<SDNode*>::iterator ii = Built.begin(), ee = Built.end(); for (SDNode *N : Built)
ii != ee; ++ii) AddToWorkList(N);
AddToWorkList(*ii);
return S; return S;
} }
@ -11217,28 +11204,20 @@ SDValue DAGCombiner::BuildSDIV(SDNode *N) {
/// multiplying by a magic number. See: /// multiplying by a magic number. See:
/// <http://the.wall.riscom.net/books/proc/ppc/cwg/code2.html> /// <http://the.wall.riscom.net/books/proc/ppc/cwg/code2.html>
SDValue DAGCombiner::BuildUDIV(SDNode *N) { SDValue DAGCombiner::BuildUDIV(SDNode *N) {
const APInt *Divisor; ConstantSDNode *C = isConstOrConstSplat(N->getOperand(1));
if (N->getValueType(0).isVector()) { if (!C)
// Handle splat vectors. return SDValue();
BuildVectorSDNode *BV = cast<BuildVectorSDNode>(N->getOperand(1));
if (ConstantSDNode *C = BV->getConstantSplatValue())
Divisor = &C->getAPIntValue();
else
return SDValue();
} else {
Divisor = &cast<ConstantSDNode>(N->getOperand(1))->getAPIntValue();
}
// Avoid division by zero. // Avoid division by zero.
if (!*Divisor) if (!C->getAPIntValue())
return SDValue(); return SDValue();
std::vector<SDNode*> Built; std::vector<SDNode*> Built;
SDValue S = TLI.BuildUDIV(N, *Divisor, DAG, LegalOperations, &Built); SDValue S =
TLI.BuildUDIV(N, C->getAPIntValue(), DAG, LegalOperations, &Built);
for (std::vector<SDNode*>::iterator ii = Built.begin(), ee = Built.end(); for (SDNode *N : Built)
ii != ee; ++ii) AddToWorkList(N);
AddToWorkList(*ii);
return S; return S;
} }

View File

@ -151,3 +151,38 @@ define <8 x i32> @test9(<8 x i32> %a) {
; AVX: vpsrad $2 ; AVX: vpsrad $2
; AVX: vpadd ; AVX: vpadd
} }
define <8 x i32> @test10(<8 x i32> %a) {
%rem = urem <8 x i32> %a, <i32 7, i32 7, i32 7, i32 7,i32 7, i32 7, i32 7, i32 7>
ret <8 x i32> %rem
; AVX-LABEL: test10:
; AVX: vpermd
; AVX: vpmuludq
; AVX: vshufps $-35
; AVX: vpmuludq
; AVX: vshufps $-35
; AVX: vpsubd
; AVX: vpsrld $1
; AVX: vpadd
; AVX: vpsrld $2
; AVX: vpmulld
}
define <8 x i32> @test11(<8 x i32> %a) {
%rem = srem <8 x i32> %a, <i32 7, i32 7, i32 7, i32 7,i32 7, i32 7, i32 7, i32 7>
ret <8 x i32> %rem
; AVX-LABEL: test11:
; AVX: vpermd
; AVX: vpmuldq
; AVX: vshufps $-35
; AVX: vpmuldq
; AVX: vshufps $-35
; AVX: vpshufd $-40
; AVX: vpadd
; AVX: vpsrld $31
; AVX: vpsrad $2
; AVX: vpadd
; AVX: vpmulld
}