From d717e202a2d50d47b96534dbf67b8aa6ea01b912 Mon Sep 17 00:00:00 2001 From: Arnold Schwaighofer Date: Fri, 19 Apr 2013 21:03:36 +0000 Subject: [PATCH] LoopVectorizer: Use matcher from PatternMatch.h for the min/max patterns Also turn some static free functions into static class member functions to avoid having to spell out the class scope for the enums all the time. No functionality change intended. git-svn-id: https://llvm.org/svn/llvm-project/llvm/trunk@179886 91177308-0d34-0410-b5e6-96231b3b80d8 --- lib/Transforms/Vectorize/LoopVectorize.cpp | 200 ++++++++++----------- 1 file changed, 99 insertions(+), 101 deletions(-) diff --git a/lib/Transforms/Vectorize/LoopVectorize.cpp b/lib/Transforms/Vectorize/LoopVectorize.cpp index 6ca76229632..0c88ba7835d 100644 --- a/lib/Transforms/Vectorize/LoopVectorize.cpp +++ b/lib/Transforms/Vectorize/LoopVectorize.cpp @@ -78,6 +78,7 @@ #include "llvm/Pass.h" #include "llvm/Support/CommandLine.h" #include "llvm/Support/Debug.h" +#include "llvm/Support/PatternMatch.h" #include "llvm/Support/raw_ostream.h" #include "llvm/Target/TargetLibraryInfo.h" #include "llvm/Transforms/Scalar.h" @@ -87,6 +88,7 @@ #include using namespace llvm; +using namespace llvm::PatternMatch; static cl::opt VectorizationFactor("force-vector-width", cl::init(0), cl::Hidden, @@ -357,14 +359,23 @@ public: IK_ReversePtrInduction ///< Reverse ptr indvar. Step = - sizeof(elem). }; + // This enum represents the kind of minmax reduction. + enum MinMaxReductionKind { + MRK_Invalid, + MRK_UIntMin, + MRK_UIntMax, + MRK_SIntMin, + MRK_SIntMax + }; + /// This POD struct holds information about reduction variables. 
struct ReductionDescriptor { ReductionDescriptor() : StartValue(0), LoopExitInstr(0), - Kind(RK_NoReduction) {} + Kind(RK_NoReduction), MinMaxKind(MRK_Invalid) {} ReductionDescriptor(Value *Start, Instruction *Exit, ReductionKind K, - CmpInst::Predicate P) - : StartValue(Start), LoopExitInstr(Exit), Kind(K), MinMaxPred(P) {} + MinMaxReductionKind MK) + : StartValue(Start), LoopExitInstr(Exit), Kind(K), MinMaxKind(MK) {} // The starting value of the reduction. // It does not have to be zero! @@ -374,16 +385,16 @@ public: // The kind of the reduction. ReductionKind Kind; // If this a min/max reduction the kind of reduction. - CmpInst::Predicate MinMaxPred; + MinMaxReductionKind MinMaxKind; }; /// This POD struct holds information about a potential reduction operation. struct ReductionInstDesc { ReductionInstDesc(bool IsRedux, Instruction *I) : - IsReduction(IsRedux), PatternLastInst(I), Predicate(ICmpInst::ICMP_EQ) {} + IsReduction(IsRedux), PatternLastInst(I), MinMaxKind(MRK_Invalid) {} - ReductionInstDesc(Instruction *I, CmpInst::Predicate P) : - IsReduction(true), PatternLastInst(I), Predicate(P) {} + ReductionInstDesc(Instruction *I, MinMaxReductionKind K) : + IsReduction(true), PatternLastInst(I), MinMaxKind(K) {} // Is this instruction a reduction candidate. bool IsReduction; @@ -391,7 +402,7 @@ public: // pattern), or the current reduction instruction otherwise. Instruction *PatternLastInst; // If this is a min/max pattern the comparison predicate. - CmpInst::Predicate Predicate; + MinMaxReductionKind MinMaxKind; }; // This POD struct holds information about the memory runtime legality @@ -482,6 +493,11 @@ public: /// Returns the information that we collected about runtime memory check. RuntimePointerCheck *getRuntimePointerCheck() { return &PtrRtCheck; } + + /// This function returns the identity element (or neutral element) for + /// the operation K. 
+ static Constant *getReductionIdentity(ReductionKind K, Type *Tp, + MinMaxReductionKind MinMaxK); private: /// Check if a single basic block loop is vectorizable. /// At this point we know that this is a loop with a constant trip count @@ -514,7 +530,11 @@ private: /// compare instruction to the select instruction and stores this pointer in /// 'PatternLastInst' member of the returned struct. ReductionInstDesc isReductionInstr(Instruction *I, ReductionKind Kind, - ReductionInstDesc Desc); + ReductionInstDesc &Desc); + /// Returns true if the instruction is a Select(ICmp(X, Y), X, Y) instruction + /// pattern corresponding to a min(X, Y) or max(X, Y). + static ReductionInstDesc isMinMaxSelectCmpPattern(Instruction *I, + ReductionInstDesc &Prev); /// Returns the induction kind of Phi. This function may return NoInduction /// if the PHI is not an induction variable. InductionKind isInductionVariable(PHINode *Phi); @@ -1461,44 +1481,40 @@ InnerLoopVectorizer::createEmptyLoop(LoopVectorizationLegality *Legal) { /// This function returns the identity element (or neutral element) for /// the operation K. -static Constant* -getReductionIdentity(LoopVectorizationLegality::ReductionKind K, Type *Tp, - CmpInst::Predicate Pred) { +Constant* +LoopVectorizationLegality::getReductionIdentity(ReductionKind K, Type *Tp, + MinMaxReductionKind MinMaxK) { switch (K) { - case LoopVectorizationLegality:: RK_IntegerXor: - case LoopVectorizationLegality:: RK_IntegerAdd: - case LoopVectorizationLegality:: RK_IntegerOr: + case RK_IntegerXor: + case RK_IntegerAdd: + case RK_IntegerOr: // Adding, Xoring, Oring zero to a number does not change it. return ConstantInt::get(Tp, 0); - case LoopVectorizationLegality:: RK_IntegerMult: + case RK_IntegerMult: // Multiplying a number by 1 does not change it. return ConstantInt::get(Tp, 1); - case LoopVectorizationLegality:: RK_IntegerAnd: + case RK_IntegerAnd: // AND-ing a number with an all-1 value does not change it. 
return ConstantInt::get(Tp, -1, true); - case LoopVectorizationLegality:: RK_FloatMult: + case RK_FloatMult: // Multiplying a number by 1 does not change it. return ConstantFP::get(Tp, 1.0L); - case LoopVectorizationLegality:: RK_FloatAdd: + case RK_FloatAdd: // Adding zero to a number does not change it. return ConstantFP::get(Tp, 0.0L); - case LoopVectorizationLegality:: RK_IntegerMinMax: - switch(Pred) { + case RK_IntegerMinMax: + switch(MinMaxK) { default: llvm_unreachable("Unknown min/max predicate"); - case CmpInst::ICMP_ULT: - case CmpInst::ICMP_ULE: + case MRK_UIntMin: return ConstantInt::getAllOnesValue(Tp); - case CmpInst::ICMP_UGT: - case CmpInst::ICMP_UGE: + case MRK_UIntMax: return ConstantInt::get(Tp, 0); - case CmpInst::ICMP_SLT: - case CmpInst::ICMP_SLE: { + case MRK_SIntMin: { unsigned BitWidth = Tp->getPrimitiveSizeInBits(); return ConstantInt::get(Tp->getContext(), APInt::getSignedMaxValue(BitWidth)); } - case CmpInst::ICMP_SGT: - case CmpInst::ICMP_SGE: { + case LoopVectorizationLegality::MRK_SIntMax: { unsigned BitWidth = Tp->getPrimitiveSizeInBits(); return ConstantInt::get(Tp->getContext(), APInt::getSignedMinValue(BitWidth)); @@ -1638,8 +1654,26 @@ getReductionBinOp(LoopVectorizationLegality::ReductionKind Kind) { } } -Value *createMinMaxOp(IRBuilder<> &Builder, ICmpInst::Predicate P, Value *Left, +Value *createMinMaxOp(IRBuilder<> &Builder, + LoopVectorizationLegality::MinMaxReductionKind RK, + Value *Left, Value *Right) { + CmpInst::Predicate P = CmpInst::ICMP_NE; + switch (RK) { + default: + llvm_unreachable("Unknown min/max reduction kind"); + case LoopVectorizationLegality::MRK_UIntMin: + P = CmpInst::ICMP_ULT; + break; + case LoopVectorizationLegality::MRK_UIntMax: + P = CmpInst::ICMP_UGT; + break; + case LoopVectorizationLegality::MRK_SIntMin: + P = CmpInst::ICMP_SLT; + break; + case LoopVectorizationLegality::MRK_SIntMax: + P = CmpInst::ICMP_SGT; + } Value *Cmp = Builder.CreateICmp(P, Left, Right, "rdx.minmax.cmp"); Value *Select = 
Builder.CreateSelect(Cmp, Left, Right, "rdx.minmax.select"); return Select; @@ -1708,8 +1742,10 @@ InnerLoopVectorizer::vectorizeLoop(LoopVectorizationLegality *Legal) { // Find the reduction identity variable. Zero for addition, or, xor, // one for multiplication, -1 for And. - Constant *Iden = getReductionIdentity(RdxDesc.Kind, VecTy->getScalarType(), - RdxDesc.MinMaxPred); + Constant *Iden = + LoopVectorizationLegality::getReductionIdentity(RdxDesc.Kind, + VecTy->getScalarType(), + RdxDesc.MinMaxKind); Constant *Identity = ConstantVector::getSplat(VF, Iden); // This vector is the Identity vector where the first element is the @@ -1764,7 +1800,7 @@ InnerLoopVectorizer::vectorizeLoop(LoopVectorizationLegality *Legal) { RdxParts[part], ReducedPartRdx, "bin.rdx"); else - ReducedPartRdx = createMinMaxOp(Builder, RdxDesc.MinMaxPred, + ReducedPartRdx = createMinMaxOp(Builder, RdxDesc.MinMaxKind, ReducedPartRdx, RdxParts[part]); } @@ -1794,7 +1830,7 @@ InnerLoopVectorizer::vectorizeLoop(LoopVectorizationLegality *Legal) { TmpVec = Builder.CreateBinOp((Instruction::BinaryOps)Op, TmpVec, Shuf, "bin.rdx"); else - TmpVec = createMinMaxOp(Builder, RdxDesc.MinMaxPred, TmpVec, Shuf); + TmpVec = createMinMaxOp(Builder, RdxDesc.MinMaxKind, TmpVec, Shuf); } // The result is in the first element of the vector. @@ -2894,7 +2930,7 @@ bool LoopVectorizationLegality::AddReductionVar(PHINode *Phi, // Save the description of this reduction variable. ReductionDescriptor RD(RdxStart, ExitInstruction, Kind, - ReduxDesc.Predicate); + ReduxDesc.MinMaxKind); Reductions[Phi] = RD; // We've ended the cycle. This is a reduction variable if we have an // outside user and it has a binary op. 
@@ -2905,90 +2941,52 @@ bool LoopVectorizationLegality::AddReductionVar(PHINode *Phi, return false; } -static CmpInst::Predicate getPredicateSense(CmpInst::Predicate P, - bool ShouldRevert) { - if (!ShouldRevert) return P; - - switch(P) { - default: - llvm_unreachable("Unknown predicate sense"); - case CmpInst::ICMP_UGT: - case CmpInst::ICMP_UGE: - return CmpInst::ICMP_ULT; - case CmpInst::ICMP_SGT: - case CmpInst::ICMP_SGE: - return CmpInst::ICMP_SLT; - case CmpInst::ICMP_ULT: - case CmpInst::ICMP_ULE: - return CmpInst::ICMP_UGT; - case CmpInst::ICMP_SLT: - case CmpInst::ICMP_SLE: - return CmpInst::ICMP_SGT; - } -} - /// Returns true if the instruction is a Select(ICmp(X, Y), X, Y) instruction /// pattern corresponding to a min(X, Y) or max(X, Y). -static LoopVectorizationLegality::ReductionInstDesc -isMinMaxSelectCmpPattern(Instruction *I) { +LoopVectorizationLegality::ReductionInstDesc +LoopVectorizationLegality::isMinMaxSelectCmpPattern(Instruction *I, ReductionInstDesc &Prev) { assert((isa(I) || isa(I)) && "Expect a select instruction"); ICmpInst *Cmp = 0; SelectInst *Select = 0; - // Look for a select(icmp(),...) pattern. Only handle integer reductions for - // now. - if ((Select = dyn_cast(I))) { - if (!(Cmp = dyn_cast(I->getOperand(0)))) - return LoopVectorizationLegality::ReductionInstDesc(false, I); - // Only handle the single user case - if (!Cmp->hasOneUse()) - return LoopVectorizationLegality::ReductionInstDesc(false, I); - } else if ((Cmp = dyn_cast(I))) { - // Only handle the single user case. - if (!Cmp->hasOneUse()) - return LoopVectorizationLegality::ReductionInstDesc(false, I); - // Look for the select. - if (!(Select = dyn_cast(*I->use_begin()))) - return LoopVectorizationLegality::ReductionInstDesc(false, I); - // Compare must be the first operand of the select. - if (Select->getOperand(0) != Cmp) - return LoopVectorizationLegality::ReductionInstDesc(false, I); + // We must handle the select(cmp()) as a single instruction. 
Advance to the + // select. + if ((Cmp = dyn_cast(I))) { + if (!Cmp->hasOneUse() || !(Select = dyn_cast(*I->use_begin()))) + return ReductionInstDesc(false, I); + return ReductionInstDesc(Select, Prev.MinMaxKind); } - CmpInst::Predicate Pred = Cmp->getPredicate(); - - // Only (u/s)lt/gt/ge/le are min or max patterns. - if (Pred == CmpInst::ICMP_EQ || - Pred == CmpInst::ICMP_NE) - return LoopVectorizationLegality::ReductionInstDesc(false, I); - - Value *SelectOp1 = Select->getOperand(1); - Value *SelectOp2 = Select->getOperand(2); + // Only handle single use cases for now. + if (!(Select = dyn_cast(I))) + return ReductionInstDesc(false, I); + if (!(Cmp = dyn_cast(I->getOperand(0)))) + return ReductionInstDesc(false, I); + if (!Cmp->hasOneUse()) + return ReductionInstDesc(false, I); Value *CmpLeft = Cmp->getOperand(0); Value *CmpRight = Cmp->getOperand(1); - // Can have reversed sense. - // select(slt(X, Y), Y, X) == select(sge(X, Y), X, Y). - bool IsInverted = (SelectOp2 == CmpLeft && SelectOp1 == CmpRight); - bool IsMinMaxPattern = (SelectOp1 == CmpLeft && SelectOp2 == CmpRight) || - IsInverted; + // Look for a min/max pattern. + if (m_UMin(m_Value(CmpLeft), m_Value(CmpRight)).match(Select)) + return ReductionInstDesc(Select, MRK_UIntMin); + else if (m_UMax(m_Value(CmpLeft), m_Value(CmpRight)).match(Select)) + return ReductionInstDesc(Select, MRK_UIntMax); + else if (m_SMax(m_Value(CmpLeft), m_Value(CmpRight)).match(Select)) + return ReductionInstDesc(Select, MRK_SIntMax); + else if (m_SMin(m_Value(CmpLeft), m_Value(CmpRight)).match(Select)) + return ReductionInstDesc(Select, MRK_SIntMin); - // Advance the instruction pointer from the icmp to the select instruction. 
- if (IsMinMaxPattern) { - CmpInst::Predicate P = getPredicateSense(Pred, IsInverted); - return LoopVectorizationLegality::ReductionInstDesc(Select, P); - } - - return LoopVectorizationLegality::ReductionInstDesc(false, I); + return ReductionInstDesc(false, I); } LoopVectorizationLegality::ReductionInstDesc LoopVectorizationLegality::isReductionInstr(Instruction *I, ReductionKind Kind, - ReductionInstDesc Desc) { + ReductionInstDesc &Prev) { bool FP = I->getType()->isFloatingPointTy(); bool FastMath = (FP && I->isCommutative() && I->isAssociative()); switch (I->getOpcode()) { @@ -2997,7 +2995,7 @@ LoopVectorizationLegality::isReductionInstr(Instruction *I, case Instruction::PHI: if (FP && (Kind != RK_FloatMult && Kind != RK_FloatAdd)) return ReductionInstDesc(false, I); - return ReductionInstDesc(I, Desc.Predicate); + return ReductionInstDesc(I, Prev.MinMaxKind); case Instruction::Sub: case Instruction::Add: return ReductionInstDesc(Kind == RK_IntegerAdd, I); @@ -3017,7 +3015,7 @@ LoopVectorizationLegality::isReductionInstr(Instruction *I, case Instruction::Select: if (Kind != RK_IntegerMinMax) return ReductionInstDesc(false, I); - return isMinMaxSelectCmpPattern(I); + return isMinMaxSelectCmpPattern(I, Prev); } }