DAGCombiner: Simplify code a bit, make more transforms work with vectors.

git-svn-id: https://llvm.org/svn/llvm-project/llvm/trunk@207338 91177308-0d34-0410-b5e6-96231b3b80d8
This commit is contained in:
Benjamin Kramer 2014-04-26 23:09:49 +00:00
parent eadbda3320
commit eb3430cfbd
2 changed files with 72 additions and 58 deletions

View File

@ -644,8 +644,13 @@ static ConstantSDNode *isConstOrConstSplat(SDValue N) {
if (ConstantSDNode *CN = dyn_cast<ConstantSDNode>(N))
return CN;
if (BuildVectorSDNode *BV = dyn_cast<BuildVectorSDNode>(N))
return BV->getConstantSplatValue();
if (BuildVectorSDNode *BV = dyn_cast<BuildVectorSDNode>(N)) {
ConstantSDNode *CN = BV->getConstantSplatValue();
// BuildVectors can truncate their operands. Ignore that case here.
if (CN && CN->getValueType(0) == N.getValueType().getScalarType())
return CN;
}
return nullptr;
}
@ -1957,8 +1962,8 @@ SDValue DAGCombiner::visitMUL(SDNode *N) {
SDValue DAGCombiner::visitSDIV(SDNode *N) {
SDValue N0 = N->getOperand(0);
SDValue N1 = N->getOperand(1);
ConstantSDNode *N0C = dyn_cast<ConstantSDNode>(N0.getNode());
ConstantSDNode *N1C = dyn_cast<ConstantSDNode>(N1.getNode());
ConstantSDNode *N0C = isConstOrConstSplat(N0);
ConstantSDNode *N1C = isConstOrConstSplat(N1);
EVT VT = N->getValueType(0);
// fold vector ops
@ -1985,25 +1990,15 @@ SDValue DAGCombiner::visitSDIV(SDNode *N) {
N0, N1);
}
const APInt *Divisor = nullptr;
if (N1C) {
Divisor = &N1C->getAPIntValue();
} else if (N1.getValueType().isVector() &&
N1->getOpcode() == ISD::BUILD_VECTOR) {
BuildVectorSDNode *BV = cast<BuildVectorSDNode>(N->getOperand(1));
if (ConstantSDNode *C = BV->getConstantSplatValue())
Divisor = &C->getAPIntValue();
}
// fold (sdiv X, pow2) -> simple ops after legalize
if (Divisor && !!*Divisor &&
(Divisor->isPowerOf2() || (-*Divisor).isPowerOf2())) {
if (N1C && !N1C->isNullValue() && (N1C->getAPIntValue().isPowerOf2() ||
(-N1C->getAPIntValue()).isPowerOf2())) {
// If dividing by powers of two is cheap, then don't perform the following
// fold.
if (TLI.isPow2DivCheap())
return SDValue();
unsigned lg2 = Divisor->countTrailingZeros();
unsigned lg2 = N1C->getAPIntValue().countTrailingZeros();
// Splat the sign bit into the register
SDValue SGN =
@ -2025,7 +2020,7 @@ SDValue DAGCombiner::visitSDIV(SDNode *N) {
// If we're dividing by a positive value, we're done. Otherwise, we must
// negate the result.
if (Divisor->isNonNegative())
if (N1C->getAPIntValue().isNonNegative())
return SRA;
AddToWorkList(SRA.getNode());
@ -2034,7 +2029,7 @@ SDValue DAGCombiner::visitSDIV(SDNode *N) {
// if integer divide is expensive and we satisfy the requirements, emit an
// alternate sequence.
if ((N1C || N1->getOpcode() == ISD::BUILD_VECTOR) && !TLI.isIntDivCheap()) {
if (N1C && !TLI.isIntDivCheap()) {
SDValue Op = BuildSDIV(N);
if (Op.getNode()) return Op;
}
@ -2052,8 +2047,8 @@ SDValue DAGCombiner::visitSDIV(SDNode *N) {
SDValue DAGCombiner::visitUDIV(SDNode *N) {
SDValue N0 = N->getOperand(0);
SDValue N1 = N->getOperand(1);
ConstantSDNode *N0C = dyn_cast<ConstantSDNode>(N0.getNode());
ConstantSDNode *N1C = dyn_cast<ConstantSDNode>(N1.getNode());
ConstantSDNode *N0C = isConstOrConstSplat(N0);
ConstantSDNode *N1C = isConstOrConstSplat(N1);
EVT VT = N->getValueType(0);
// fold vector ops
@ -2086,7 +2081,7 @@ SDValue DAGCombiner::visitUDIV(SDNode *N) {
}
}
// fold (udiv x, c) -> alternate
if ((N1C || N1->getOpcode() == ISD::BUILD_VECTOR) && !TLI.isIntDivCheap()) {
if (N1C && !TLI.isIntDivCheap()) {
SDValue Op = BuildUDIV(N);
if (Op.getNode()) return Op;
}
@ -2104,8 +2099,8 @@ SDValue DAGCombiner::visitUDIV(SDNode *N) {
SDValue DAGCombiner::visitSREM(SDNode *N) {
SDValue N0 = N->getOperand(0);
SDValue N1 = N->getOperand(1);
ConstantSDNode *N0C = dyn_cast<ConstantSDNode>(N0);
ConstantSDNode *N1C = dyn_cast<ConstantSDNode>(N1);
ConstantSDNode *N0C = isConstOrConstSplat(N0);
ConstantSDNode *N1C = isConstOrConstSplat(N1);
EVT VT = N->getValueType(0);
// fold (srem c1, c2) -> c1%c2
@ -2146,8 +2141,8 @@ SDValue DAGCombiner::visitSREM(SDNode *N) {
SDValue DAGCombiner::visitUREM(SDNode *N) {
SDValue N0 = N->getOperand(0);
SDValue N1 = N->getOperand(1);
ConstantSDNode *N0C = dyn_cast<ConstantSDNode>(N0);
ConstantSDNode *N1C = dyn_cast<ConstantSDNode>(N1);
ConstantSDNode *N0C = isConstOrConstSplat(N0);
ConstantSDNode *N1C = isConstOrConstSplat(N1);
EVT VT = N->getValueType(0);
// fold (urem c1, c2) -> c1%c2
@ -11187,28 +11182,20 @@ SDValue DAGCombiner::SimplifySetCC(EVT VT, SDValue N0,
/// multiplying by a magic number. See:
/// <http://the.wall.riscom.net/books/proc/ppc/cwg/code2.html>
SDValue DAGCombiner::BuildSDIV(SDNode *N) {
const APInt *Divisor;
if (N->getValueType(0).isVector()) {
// Handle splat vectors.
BuildVectorSDNode *BV = cast<BuildVectorSDNode>(N->getOperand(1));
if (ConstantSDNode *C = BV->getConstantSplatValue())
Divisor = &C->getAPIntValue();
else
return SDValue();
} else {
Divisor = &cast<ConstantSDNode>(N->getOperand(1))->getAPIntValue();
}
ConstantSDNode *C = isConstOrConstSplat(N->getOperand(1));
if (!C)
return SDValue();
// Avoid division by zero.
if (!*Divisor)
if (!C->getAPIntValue())
return SDValue();
std::vector<SDNode*> Built;
SDValue S = TLI.BuildSDIV(N, *Divisor, DAG, LegalOperations, &Built);
SDValue S =
TLI.BuildSDIV(N, C->getAPIntValue(), DAG, LegalOperations, &Built);
for (std::vector<SDNode*>::iterator ii = Built.begin(), ee = Built.end();
ii != ee; ++ii)
AddToWorkList(*ii);
for (SDNode *N : Built)
AddToWorkList(N);
return S;
}
@ -11217,28 +11204,20 @@ SDValue DAGCombiner::BuildSDIV(SDNode *N) {
/// multiplying by a magic number. See:
/// <http://the.wall.riscom.net/books/proc/ppc/cwg/code2.html>
SDValue DAGCombiner::BuildUDIV(SDNode *N) {
const APInt *Divisor;
if (N->getValueType(0).isVector()) {
// Handle splat vectors.
BuildVectorSDNode *BV = cast<BuildVectorSDNode>(N->getOperand(1));
if (ConstantSDNode *C = BV->getConstantSplatValue())
Divisor = &C->getAPIntValue();
else
return SDValue();
} else {
Divisor = &cast<ConstantSDNode>(N->getOperand(1))->getAPIntValue();
}
ConstantSDNode *C = isConstOrConstSplat(N->getOperand(1));
if (!C)
return SDValue();
// Avoid division by zero.
if (!*Divisor)
if (!C->getAPIntValue())
return SDValue();
std::vector<SDNode*> Built;
SDValue S = TLI.BuildUDIV(N, *Divisor, DAG, LegalOperations, &Built);
SDValue S =
TLI.BuildUDIV(N, C->getAPIntValue(), DAG, LegalOperations, &Built);
for (std::vector<SDNode*>::iterator ii = Built.begin(), ee = Built.end();
ii != ee; ++ii)
AddToWorkList(*ii);
for (SDNode *N : Built)
AddToWorkList(N);
return S;
}

View File

@ -151,3 +151,38 @@ define <8 x i32> @test9(<8 x i32> %a) {
; AVX: vpsrad $2
; AVX: vpadd
}
define <8 x i32> @test10(<8 x i32> %a) {
%rem = urem <8 x i32> %a, <i32 7, i32 7, i32 7, i32 7,i32 7, i32 7, i32 7, i32 7>
ret <8 x i32> %rem
; AVX-LABEL: test10:
; AVX: vpermd
; AVX: vpmuludq
; AVX: vshufps $-35
; AVX: vpmuludq
; AVX: vshufps $-35
; AVX: vpsubd
; AVX: vpsrld $1
; AVX: vpadd
; AVX: vpsrld $2
; AVX: vpmulld
}
define <8 x i32> @test11(<8 x i32> %a) {
%rem = srem <8 x i32> %a, <i32 7, i32 7, i32 7, i32 7,i32 7, i32 7, i32 7, i32 7>
ret <8 x i32> %rem
; AVX-LABEL: test11:
; AVX: vpermd
; AVX: vpmuldq
; AVX: vshufps $-35
; AVX: vpmuldq
; AVX: vshufps $-35
; AVX: vpshufd $-40
; AVX: vpadd
; AVX: vpsrld $31
; AVX: vpsrad $2
; AVX: vpadd
; AVX: vpmulld
}