diff --git a/lib/Transforms/Scalar/Reassociate.cpp b/lib/Transforms/Scalar/Reassociate.cpp index cb408a137ea..c4079e37a11 100644 --- a/lib/Transforms/Scalar/Reassociate.cpp +++ b/lib/Transforms/Scalar/Reassociate.cpp @@ -31,10 +31,12 @@ #include "llvm/Pass.h" #include "llvm/Assembly/Writer.h" #include "llvm/Support/CFG.h" +#include "llvm/Support/IRBuilder.h" #include "llvm/Support/Debug.h" #include "llvm/Support/ValueHandle.h" #include "llvm/Support/raw_ostream.h" #include "llvm/ADT/PostOrderIterator.h" +#include "llvm/ADT/STLExtras.h" #include "llvm/ADT/Statistic.h" #include "llvm/ADT/DenseMap.h" #include @@ -71,6 +73,45 @@ static void PrintOps(Instruction *I, const SmallVectorImpl &Ops) { } #endif +namespace { + /// \brief Utility class representing a base and exponent pair which form one + /// factor of some product. + struct Factor { + Value *Base; + unsigned Power; + + Factor(Value *Base, unsigned Power) : Base(Base), Power(Power) {} + + /// \brief Sort factors by their Base. + struct BaseSorter { + bool operator()(const Factor &LHS, const Factor &RHS) { + return LHS.Base < RHS.Base; + } + }; + + /// \brief Compare factors for equal bases. + struct BaseEqual { + bool operator()(const Factor &LHS, const Factor &RHS) { + return LHS.Base == RHS.Base; + } + }; + + /// \brief Sort factors in descending order by their power. + struct PowerDescendingSorter { + bool operator()(const Factor &LHS, const Factor &RHS) { + return LHS.Power > RHS.Power; + } + }; + + /// \brief Compare factors for equal powers. + struct PowerEqual { + bool operator()(const Factor &LHS, const Factor &RHS) { + return LHS.Power == RHS.Power; + } + }; + }; +} + namespace { class Reassociate : public FunctionPass { DenseMap RankMap; @@ -98,6 +139,11 @@ namespace { Value *OptimizeExpression(BinaryOperator *I, SmallVectorImpl &Ops); Value *OptimizeAdd(Instruction *I, SmallVectorImpl &Ops); + bool collectMultiplyFactors(SmallVectorImpl &Ops, + SmallVectorImpl &Factors); + Value *buildMinimalMultiplyDAG(IRBuilder<> &Builder, + SmallVectorImpl &Factors); + Value *OptimizeMul(BinaryOperator *I, SmallVectorImpl &Ops); void LinearizeExprTree(BinaryOperator *I, SmallVectorImpl &Ops); void LinearizeExpr(BinaryOperator *I); Value *RemoveFactorFromExpression(Value *V, Value *Factor); @@ -888,6 +934,199 @@ Value *Reassociate::OptimizeAdd(Instruction *I, return 0; } +namespace { + /// \brief Predicate tests whether a ValueEntry's op is in a map. + struct IsValueInMap { + const DenseMap ⤅ + + IsValueInMap(const DenseMap &Map) : Map(Map) {} + + bool operator()(const ValueEntry &Entry) { + return Map.find(Entry.Op) != Map.end(); + } + }; +} + +/// \brief Build up a vector of value/power pairs factoring a product. +/// +/// Given a series of multiplication operands, build a vector of factors and +/// the powers each is raised to when forming the final product. Sort them in +/// the order of descending power. +/// +/// (x*x) -> [(x, 2)] +/// ((x*x)*x) -> [(x, 3)] +/// ((((x*y)*x)*y)*x) -> [(x, 3), (y, 2)] +/// +/// \returns Whether any factors have a power greater than one. +bool Reassociate::collectMultiplyFactors(SmallVectorImpl &Ops, + SmallVectorImpl &Factors) { + unsigned FactorPowerSum = 0; + DenseMap FactorCounts; + for (unsigned LastIdx = 0, Idx = 0, Size = Ops.size(); Idx < Size; ++Idx) { + // Note that 'use_empty' uses means the only use is in the linearized tree + // represented by Ops -- we remove the values from the actual operations to + // reduce their use count. + if (!Ops[Idx].Op->use_empty()) { + if (LastIdx == Idx) + ++LastIdx; + continue; + } + if (LastIdx == Idx || Ops[LastIdx].Op != Ops[Idx].Op) { + LastIdx = Idx; + continue; + } + // Track for simplification all factors which occur 2 or more times. + DenseMap::iterator CountIt; + bool Inserted; + llvm::tie(CountIt, Inserted) + = FactorCounts.insert(std::make_pair(Ops[Idx].Op, 2)); + if (Inserted) { + FactorPowerSum += 2; + Factors.push_back(Factor(Ops[Idx].Op, 2)); + } else { + ++CountIt->second; + ++FactorPowerSum; + } + } + // We can only simplify factors if the sum of the powers of our simplifiable + // factors is 4 or higher. When that is the case, we will *always* have + // a simplification. This is an important invariant to prevent cyclicly + // trying to simplify already minimal formations. + if (FactorPowerSum < 4) + return false; + + // Remove all the operands which are in the map. + Ops.erase(std::remove_if(Ops.begin(), Ops.end(), IsValueInMap(FactorCounts)), + Ops.end()); + + // Record the adjusted power for the simplification factors. We add back into + // the Ops list any values with an odd power, and make the power even. This + // allows the outer-most multiplication tree to remain in tact during + // simplification. + unsigned OldOpsSize = Ops.size(); + for (unsigned Idx = 0, Size = Factors.size(); Idx != Size; ++Idx) { + Factors[Idx].Power = FactorCounts[Factors[Idx].Base]; + if (Factors[Idx].Power & 1) { + Ops.push_back(ValueEntry(getRank(Factors[Idx].Base), Factors[Idx].Base)); + --Factors[Idx].Power; + --FactorPowerSum; + } + } + // None of the adjustments above should have reduced the sum of factor powers + // below our mininum of '4'. + assert(FactorPowerSum >= 4); + + // Patch up the sort of the ops vector by sorting the factors we added back + // onto the back, and merging the two sequences. + if (OldOpsSize != Ops.size()) { + SmallVectorImpl::iterator MiddleIt = Ops.begin() + OldOpsSize; + std::sort(MiddleIt, Ops.end()); + std::inplace_merge(Ops.begin(), MiddleIt, Ops.end()); + } + + std::sort(Factors.begin(), Factors.end(), Factor::PowerDescendingSorter()); + return true; +} + +/// \brief Build a tree of multiplies, computing the product of Ops. +static Value *buildMultiplyTree(IRBuilder<> &Builder, + SmallVectorImpl &Ops) { + if (Ops.size() == 1) + return Ops.back(); + + Value *LHS = Ops.pop_back_val(); + do { + LHS = Builder.CreateMul(LHS, Ops.pop_back_val()); + } while (!Ops.empty()); + + return LHS; +} + +/// \brief Build a minimal multiplication DAG for (a^x)*(b^y)*(c^z)*... +/// +/// Given a vector of values raised to various powers, where no two values are +/// equal and the powers are sorted in decreasing order, compute the minimal +/// DAG of multiplies to compute the final product, and return that product +/// value. +Value *Reassociate::buildMinimalMultiplyDAG(IRBuilder<> &Builder, + SmallVectorImpl &Factors) { + assert(Factors[0].Power); + SmallVector OuterProduct; + for (unsigned LastIdx = 0, Idx = 1, Size = Factors.size(); + Idx < Size && Factors[Idx].Power > 0; ++Idx) { + if (Factors[Idx].Power != Factors[LastIdx].Power) { + LastIdx = Idx; + continue; + } + + // We want to multiply across all the factors with the same power so that + // we can raise them to that power as a single entity. Build a mini tree + // for that. + SmallVector InnerProduct; + InnerProduct.push_back(Factors[LastIdx].Base); + do { + InnerProduct.push_back(Factors[Idx].Base); + ++Idx; + } while (Idx < Size && Factors[Idx].Power == Factors[LastIdx].Power); + + // Reset the base value of the first factor to the new expression tree. + // We'll remove all the factors with the same power in a second pass. + Factors[LastIdx].Base + = ReassociateExpression( + cast(buildMultiplyTree(Builder, InnerProduct))); + + LastIdx = Idx; + } + // Unique factors with equal powers -- we've folded them into the first one's + // base. + Factors.erase(std::unique(Factors.begin(), Factors.end(), + Factor::PowerEqual()), + Factors.end()); + + // Iteratively collect the base of each factor with an add power into the + // outer product, and halve each power in preparation for squaring the + // expression. + for (unsigned Idx = 0, Size = Factors.size(); Idx != Size; ++Idx) { + if (Factors[Idx].Power & 1) + OuterProduct.push_back(Factors[Idx].Base); + Factors[Idx].Power >>= 1; + } + if (Factors[0].Power) { + Value *SquareRoot = buildMinimalMultiplyDAG(Builder, Factors); + OuterProduct.push_back(SquareRoot); + OuterProduct.push_back(SquareRoot); + } + if (OuterProduct.size() == 1) + return OuterProduct.front(); + + return ReassociateExpression( + cast(buildMultiplyTree(Builder, OuterProduct))); +} + +Value *Reassociate::OptimizeMul(BinaryOperator *I, + SmallVectorImpl &Ops) { + // We can only optimize the multiplies when there is a chain of more than + // three, such that a balanced tree might require fewer total multiplies. + if (Ops.size() < 4) + return 0; + + // Try to turn linear trees of multiplies without other uses of the + // intermediate stages into minimal multiply DAGs with perfect sub-expression + // re-use. + SmallVector Factors; + if (!collectMultiplyFactors(Ops, Factors)) + return 0; // All distinct factors, so nothing left for us to do. + + IRBuilder<> Builder(I); + Value *V = buildMinimalMultiplyDAG(Builder, Factors); + if (Ops.empty()) + return V; + + ValueEntry NewEntry = ValueEntry(getRank(V), V); + Ops.insert(std::lower_bound(Ops.begin(), Ops.end(), NewEntry), NewEntry); + return 0; +} + Value *Reassociate::OptimizeExpression(BinaryOperator *I, SmallVectorImpl &Ops) { // Now that we have the linearized expression tree, try to optimize it. @@ -937,30 +1176,28 @@ Value *Reassociate::OptimizeExpression(BinaryOperator *I, // Handle destructive annihilation due to identities between elements in the // argument list here. + unsigned NumOps = Ops.size(); switch (Opcode) { default: break; case Instruction::And: case Instruction::Or: - case Instruction::Xor: { - unsigned NumOps = Ops.size(); + case Instruction::Xor: if (Value *Result = OptimizeAndOrXor(Opcode, Ops)) return Result; - IterateOptimization |= Ops.size() != NumOps; break; - } - case Instruction::Add: { - unsigned NumOps = Ops.size(); + case Instruction::Add: if (Value *Result = OptimizeAdd(I, Ops)) return Result; - IterateOptimization |= Ops.size() != NumOps; - } - break; - //case Instruction::Mul: + + case Instruction::Mul: + if (Value *Result = OptimizeMul(I, Ops)) + return Result; + break; } - if (IterateOptimization) + if (IterateOptimization || Ops.size() != NumOps) return OptimizeExpression(I, Ops); return 0; } diff --git a/test/Transforms/Reassociate/mulfactor.ll b/test/Transforms/Reassociate/mulfactor.ll index 5e6fbeb1cac..6c099b43b36 100644 --- a/test/Transforms/Reassociate/mulfactor.ll +++ b/test/Transforms/Reassociate/mulfactor.ll @@ -33,3 +33,102 @@ entry: ret i32 %d } +define i32 @test3(i32 %x) { +; (x^8) +; CHECK: @test3 +; CHECK: mul +; CHECK-NEXT: mul +; CHECK-NEXT: mul +; CHECK-NEXT: ret + +entry: + %a = mul i32 %x, %x + %b = mul i32 %a, %x + %c = mul i32 %b, %x + %d = mul i32 %c, %x + %e = mul i32 %d, %x + %f = mul i32 %e, %x + %g = mul i32 %f, %x + ret i32 %g +} + +define i32 @test4(i32 %x) { +; (x^7) +; CHECK: @test4 +; CHECK: mul +; CHECK-NEXT: mul +; CHECK-NEXT: mul +; CHECK-NEXT: mul +; CHECK-NEXT: ret + +entry: + %a = mul i32 %x, %x + %b = mul i32 %a, %x + %c = mul i32 %b, %x + %d = mul i32 %c, %x + %e = mul i32 %d, %x + %f = mul i32 %e, %x + ret i32 %f +} + +define i32 @test5(i32 %x, i32 %y) { +; (x^4) * (y^2) +; CHECK: @test5 +; CHECK: mul +; CHECK-NEXT: mul +; CHECK-NEXT: mul +; CHECK-NEXT: ret + +entry: + %a = mul i32 %x, %y + %b = mul i32 %a, %y + %c = mul i32 %b, %x + %d = mul i32 %c, %x + %e = mul i32 %d, %x + ret i32 %e +} + +define i32 @test6(i32 %x, i32 %y, i32 %z) { +; (x^5) * (y^3) * z +; CHECK: @test6 +; CHECK: mul +; CHECK-NEXT: mul +; CHECK-NEXT: mul +; CHECK-NEXT: mul +; CHECK-NEXT: mul +; CHECK-NEXT: mul +; CHECK-NEXT: ret + +entry: + %a = mul i32 %x, %y + %b = mul i32 %a, %x + %c = mul i32 %b, %y + %d = mul i32 %c, %x + %e = mul i32 %d, %y + %f = mul i32 %e, %x + %g = mul i32 %f, %z + %h = mul i32 %g, %x + ret i32 %h +} + +define i32 @test7(i32 %x, i32 %y, i32 %z) { +; (x^4) * (y^3) * (z^2) +; CHECK: @test7 +; CHECK: mul +; CHECK-NEXT: mul +; CHECK-NEXT: mul +; CHECK-NEXT: mul +; CHECK-NEXT: mul +; CHECK-NEXT: ret + +entry: + %a = mul i32 %y, %x + %b = mul i32 %a, %z + %c = mul i32 %b, %z + %d = mul i32 %c, %x + %e = mul i32 %d, %y + %f = mul i32 %e, %y + %g = mul i32 %f, %x + %h = mul i32 %g, %x + ret i32 %h +}