LoopVectorizer: Add support for floating point reductions

git-svn-id: https://llvm.org/svn/llvm-project/llvm/trunk@171812 91177308-0d34-0410-b5e6-96231b3b80d8
This commit is contained in:
Nadav Rotem 2013-01-07 23:13:00 +00:00
parent 59d152197d
commit 111e5fe7e0
2 changed files with 120 additions and 58 deletions

View File

@ -215,10 +215,6 @@ private:
/// broadcast them into a vector.
VectorParts &getVectorValue(Value *V);
/// Get a uniform vector of constant integers. We use this to get
/// vectors of ones and zeros for the reduction code.
Constant* getUniformVector(unsigned Val, Type* ScalarTy);
/// Generate a shuffle sequence that will reverse the vector Vec.
Value *reverseVector(Value *Vec);
@ -325,12 +321,14 @@ public:
/// This enum represents the kinds of reductions that we support.
enum ReductionKind {
NoReduction, ///< Not a reduction.
IntegerAdd, ///< Sum of numbers.
IntegerMult, ///< Product of numbers.
IntegerOr, ///< Bitwise or logical OR of numbers.
IntegerAnd, ///< Bitwise or logical AND of numbers.
IntegerXor ///< Bitwise or logical XOR of numbers.
RK_NoReduction, ///< Not a reduction.
RK_IntegerAdd, ///< Sum of integers.
RK_IntegerMult, ///< Product of integers.
RK_IntegerOr, ///< Bitwise or logical OR of numbers.
RK_IntegerAnd, ///< Bitwise or logical AND of numbers.
RK_IntegerXor, ///< Bitwise or logical XOR of numbers.
RK_FloatAdd, ///< Sum of floats.
RK_FloatMult ///< Product of floats.
};
/// This enum represents the kinds of inductions that we support.
@ -343,8 +341,8 @@ public:
/// This POD struct holds information about reduction variables.
struct ReductionDescriptor {
ReductionDescriptor() : StartValue(0), LoopExitInstr(0), Kind(NoReduction) {
}
ReductionDescriptor() : StartValue(0), LoopExitInstr(0),
Kind(RK_NoReduction) {}
ReductionDescriptor(Value *Start, Instruction *Exit, ReductionKind K)
: StartValue(Start), LoopExitInstr(Exit), Kind(K) {}
@ -790,11 +788,6 @@ InnerLoopVectorizer::getVectorValue(Value *V) {
return WidenMap.get(V);
}
Constant*
InnerLoopVectorizer::getUniformVector(unsigned Val, Type* ScalarTy) {
return ConstantVector::getSplat(VF, ConstantInt::get(ScalarTy, Val, true));
}
Value *InnerLoopVectorizer::reverseVector(Value *Vec) {
assert(Vec->getType()->isVectorTy() && "Invalid type");
SmallVector<Constant*, 8> ShuffleMask;
@ -1215,20 +1208,26 @@ InnerLoopVectorizer::createEmptyLoop(LoopVectorizationLegality *Legal) {
/// This function returns the identity element (or neutral element) for
/// the operation K.
static unsigned
getReductionIdentity(LoopVectorizationLegality::ReductionKind K) {
static Constant*
getReductionIdentity(LoopVectorizationLegality::ReductionKind K, Type *Tp) {
switch (K) {
case LoopVectorizationLegality::IntegerXor:
case LoopVectorizationLegality::IntegerAdd:
case LoopVectorizationLegality::IntegerOr:
case LoopVectorizationLegality:: RK_IntegerXor:
case LoopVectorizationLegality:: RK_IntegerAdd:
case LoopVectorizationLegality:: RK_IntegerOr:
// Adding, Xoring, Oring zero to a number does not change it.
return 0;
case LoopVectorizationLegality::IntegerMult:
return ConstantInt::get(Tp, 0);
case LoopVectorizationLegality:: RK_IntegerMult:
// Multiplying a number by 1 does not change it.
return 1;
case LoopVectorizationLegality::IntegerAnd:
return ConstantInt::get(Tp, 1);
case LoopVectorizationLegality:: RK_IntegerAnd:
// AND-ing a number with an all-1 value does not change it.
return -1;
return ConstantInt::get(Tp, -1, true);
case LoopVectorizationLegality:: RK_FloatMult:
// Multiplying a number by 1 does not change it.
return ConstantFP::get(Tp, 1.0L);
case LoopVectorizationLegality:: RK_FloatAdd:
// Adding zero to a number does not change it.
return ConstantFP::get(Tp, 0.0L);
default:
llvm_unreachable("Unknown reduction kind");
}
@ -1329,8 +1328,8 @@ InnerLoopVectorizer::vectorizeLoop(LoopVectorizationLegality *Legal) {
// Find the reduction identity variable. Zero for addition, or, and xor;
// one for multiplication; -1 for And.
Constant *Identity = getUniformVector(getReductionIdentity(RdxDesc.Kind),
VecTy->getScalarType());
Constant *Iden = getReductionIdentity(RdxDesc.Kind, VecTy->getScalarType());
Constant *Identity = ConstantVector::getSplat(VF, Iden);
// This vector is the Identity vector where the first element is the
// incoming scalar reduction.
@ -1378,26 +1377,34 @@ InnerLoopVectorizer::vectorizeLoop(LoopVectorizationLegality *Legal) {
Value *ReducedPartRdx = RdxParts[0];
for (unsigned part = 1; part < UF; ++part) {
switch (RdxDesc.Kind) {
case LoopVectorizationLegality::IntegerAdd:
case LoopVectorizationLegality::RK_IntegerAdd:
ReducedPartRdx =
Builder.CreateAdd(RdxParts[part], ReducedPartRdx, "add.rdx");
break;
case LoopVectorizationLegality::IntegerMult:
case LoopVectorizationLegality::RK_IntegerMult:
ReducedPartRdx =
Builder.CreateMul(RdxParts[part], ReducedPartRdx, "mul.rdx");
break;
case LoopVectorizationLegality::IntegerOr:
case LoopVectorizationLegality::RK_IntegerOr:
ReducedPartRdx =
Builder.CreateOr(RdxParts[part], ReducedPartRdx, "or.rdx");
break;
case LoopVectorizationLegality::IntegerAnd:
case LoopVectorizationLegality::RK_IntegerAnd:
ReducedPartRdx =
Builder.CreateAnd(RdxParts[part], ReducedPartRdx, "and.rdx");
break;
case LoopVectorizationLegality::IntegerXor:
case LoopVectorizationLegality::RK_IntegerXor:
ReducedPartRdx =
Builder.CreateXor(RdxParts[part], ReducedPartRdx, "xor.rdx");
break;
case LoopVectorizationLegality::RK_FloatMult:
ReducedPartRdx =
Builder.CreateFMul(RdxParts[part], ReducedPartRdx, "fmul.rdx");
break;
case LoopVectorizationLegality::RK_FloatAdd:
ReducedPartRdx =
Builder.CreateFAdd(RdxParts[part], ReducedPartRdx, "fadd.rdx");
break;
default:
llvm_unreachable("Unknown reduction operation");
}
@ -1428,21 +1435,27 @@ InnerLoopVectorizer::vectorizeLoop(LoopVectorizationLegality *Legal) {
// Emit the operation on the shuffled value.
switch (RdxDesc.Kind) {
case LoopVectorizationLegality::IntegerAdd:
case LoopVectorizationLegality::RK_IntegerAdd:
TmpVec = Builder.CreateAdd(TmpVec, Shuf, "add.rdx");
break;
case LoopVectorizationLegality::IntegerMult:
case LoopVectorizationLegality::RK_IntegerMult:
TmpVec = Builder.CreateMul(TmpVec, Shuf, "mul.rdx");
break;
case LoopVectorizationLegality::IntegerOr:
case LoopVectorizationLegality::RK_IntegerOr:
TmpVec = Builder.CreateOr(TmpVec, Shuf, "or.rdx");
break;
case LoopVectorizationLegality::IntegerAnd:
case LoopVectorizationLegality::RK_IntegerAnd:
TmpVec = Builder.CreateAnd(TmpVec, Shuf, "and.rdx");
break;
case LoopVectorizationLegality::IntegerXor:
case LoopVectorizationLegality::RK_IntegerXor:
TmpVec = Builder.CreateXor(TmpVec, Shuf, "xor.rdx");
break;
case LoopVectorizationLegality::RK_FloatMult:
TmpVec = Builder.CreateFMul(TmpVec, Shuf, "fmul.rdx");
break;
case LoopVectorizationLegality::RK_FloatAdd:
TmpVec = Builder.CreateFAdd(TmpVec, Shuf, "fadd.rdx");
break;
default:
llvm_unreachable("Unknown reduction operation");
}
@ -2074,6 +2087,7 @@ bool LoopVectorizationLegality::canVectorizeInstrs() {
// Check that this PHI type is allowed.
if (!Phi->getType()->isIntegerTy() &&
!Phi->getType()->isFloatingPointTy() &&
!Phi->getType()->isPointerTy()) {
DEBUG(dbgs() << "LV: Found an non-int non-pointer PHI.\n");
return false;
@ -2105,26 +2119,34 @@ bool LoopVectorizationLegality::canVectorizeInstrs() {
continue;
}
if (AddReductionVar(Phi, IntegerAdd)) {
if (AddReductionVar(Phi, RK_IntegerAdd)) {
DEBUG(dbgs() << "LV: Found an ADD reduction PHI."<< *Phi <<"\n");
continue;
}
if (AddReductionVar(Phi, IntegerMult)) {
if (AddReductionVar(Phi, RK_IntegerMult)) {
DEBUG(dbgs() << "LV: Found a MUL reduction PHI."<< *Phi <<"\n");
continue;
}
if (AddReductionVar(Phi, IntegerOr)) {
if (AddReductionVar(Phi, RK_IntegerOr)) {
DEBUG(dbgs() << "LV: Found an OR reduction PHI."<< *Phi <<"\n");
continue;
}
if (AddReductionVar(Phi, IntegerAnd)) {
if (AddReductionVar(Phi, RK_IntegerAnd)) {
DEBUG(dbgs() << "LV: Found an AND reduction PHI."<< *Phi <<"\n");
continue;
}
if (AddReductionVar(Phi, IntegerXor)) {
if (AddReductionVar(Phi, RK_IntegerXor)) {
DEBUG(dbgs() << "LV: Found a XOR reduction PHI."<< *Phi <<"\n");
continue;
}
if (AddReductionVar(Phi, RK_FloatMult)) {
DEBUG(dbgs() << "LV: Found an FMult reduction PHI."<< *Phi <<"\n");
continue;
}
if (AddReductionVar(Phi, RK_FloatAdd)) {
DEBUG(dbgs() << "LV: Found an FAdd reduction PHI."<< *Phi <<"\n");
continue;
}
DEBUG(dbgs() << "LV: Found an unidentified PHI."<< *Phi <<"\n");
return false;
@ -2419,6 +2441,8 @@ bool LoopVectorizationLegality::AddReductionVar(PHINode *Phi,
// This includes users of the reduction, variables (which form a cycle
// which ends in the phi node).
Instruction *ExitInstruction = 0;
// Indicates that we found a binary operation in our scan.
bool FoundBinOp = false;
// Iter is our iterator. We start with the PHI node and scan for all of the
// users of this instruction. All users must be instructions that can be
@ -2436,6 +2460,9 @@ bool LoopVectorizationLegality::AddReductionVar(PHINode *Phi,
// Did we reach the initial PHI node already ?
bool FoundStartPHI = false;
// Is this a bin op ?
FoundBinOp |= !isa<PHINode>(Iter);
// For each of the *users* of iter.
for (Value::use_iterator it = Iter->use_begin(), e = Iter->use_end();
it != e; ++it) {
@ -2475,7 +2502,7 @@ bool LoopVectorizationLegality::AddReductionVar(PHINode *Phi,
// Reductions of instructions such as Div and Sub are only
// possible if the LHS is the reduction variable.
if (!U->isCommutative() && U->getOperand(0) != Iter)
if (!U->isCommutative() && !isa<PHINode>(U) && U->getOperand(0) != Iter)
return false;
Iter = U;
@ -2484,46 +2511,52 @@ bool LoopVectorizationLegality::AddReductionVar(PHINode *Phi,
// We found a reduction var if we have reached the original
// phi node and we only have a single instruction with out-of-loop
// users.
if (FoundStartPHI && ExitInstruction) {
if (FoundStartPHI) {
// This instruction is allowed to have out-of-loop users.
AllowedExit.insert(ExitInstruction);
// Save the description of this reduction variable.
ReductionDescriptor RD(RdxStart, ExitInstruction, Kind);
Reductions[Phi] = RD;
return true;
// We've ended the cycle. This is a reduction variable if we have an
// outside user and it has a binary op.
return FoundBinOp && ExitInstruction;
}
// If we've reached the start PHI but did not find an outside user then
// this is dead code. Abort.
if (FoundStartPHI)
return false;
}
}
bool
LoopVectorizationLegality::isReductionInstr(Instruction *I,
ReductionKind Kind) {
bool FP = I->getType()->isFloatingPointTy();
bool FastMath = (FP && I->isCommutative() && I->isAssociative());
switch (I->getOpcode()) {
default:
return false;
case Instruction::PHI:
if (FP && (Kind != RK_FloatMult && Kind != RK_FloatAdd))
return false;
// possibly.
return true;
case Instruction::Sub:
case Instruction::Add:
return Kind == IntegerAdd;
return Kind == RK_IntegerAdd;
case Instruction::SDiv:
case Instruction::UDiv:
case Instruction::Mul:
return Kind == IntegerMult;
return Kind == RK_IntegerMult;
case Instruction::And:
return Kind == IntegerAnd;
return Kind == RK_IntegerAnd;
case Instruction::Or:
return Kind == IntegerOr;
return Kind == RK_IntegerOr;
case Instruction::Xor:
return Kind == IntegerXor;
}
return Kind == RK_IntegerXor;
case Instruction::FMul:
return Kind == RK_FloatMult && FastMath;
case Instruction::FAdd:
return Kind == RK_FloatAdd && FastMath;
}
}
LoopVectorizationLegality::InductionKind

View File

@ -0,0 +1,29 @@
; RUN: opt < %s -loop-vectorize -force-vector-unroll=1 -force-vector-width=4 -dce -instcombine -S | FileCheck %s
target datalayout = "e-p:64:64:64-i1:8:8-i8:8:8-i16:16:16-i32:32:32-i64:64:64-f32:32:32-f64:64:64-v64:64:64-v128:128:128-a0:0:64-s0:64:64-f80:128:128-n8:16:32:64-S128"
target triple = "x86_64-apple-macosx10.8.0"
;CHECK: @foo
;CHECK: fadd <4 x float>
;CHECK: ret
; Test input: a single-block loop that accumulates the floats A[0] .. A[199]
; into a scalar sum and returns it. The argument %n is never read in the body.
define float @foo(float* nocapture %A, i32* nocapture %n) nounwind uwtable readonly ssp {
entry:
br label %for.body
for.body: ; preds = %for.body, %entry
; Induction variable: starts at 0 and advances by 1 each iteration.
%indvars.iv = phi i64 [ 0, %entry ], [ %indvars.iv.next, %for.body ]
; Reduction PHI: starts at 0.0 and carries the running sum (%add) around the loop.
%sum.04 = phi float [ 0.000000e+00, %entry ], [ %add, %for.body ]
%arrayidx = getelementptr inbounds float* %A, i64 %indvars.iv
%0 = load float* %arrayidx, align 4, !tbaa !0
; The floating-point add that forms the reduction. It carries the 'fast'
; fast-math flag, which permits the reassociation a vectorized sum relies on.
%add = fadd fast float %sum.04, %0
%indvars.iv.next = add i64 %indvars.iv, 1
; Exit once the incremented counter, truncated to i32, equals 200.
%lftr.wideiv = trunc i64 %indvars.iv.next to i32
%exitcond = icmp eq i32 %lftr.wideiv, 200
br i1 %exitcond, label %for.end, label %for.body
for.end: ; preds = %for.body
; %add is used outside the loop: it is the value returned to the caller.
ret float %add
}
!0 = metadata !{metadata !"float", metadata !1}
!1 = metadata !{metadata !"omnipotent char", metadata !2}
!2 = metadata !{metadata !"Simple C/C++ TBAA"}