diff --git a/lib/Transforms/Scalar/Reassociate.cpp b/lib/Transforms/Scalar/Reassociate.cpp index fe63381077f..e0528f76670 100644 --- a/lib/Transforms/Scalar/Reassociate.cpp +++ b/lib/Transforms/Scalar/Reassociate.cpp @@ -88,7 +88,7 @@ namespace { private: void BuildRankMap(Function &F); unsigned getRank(Value *V); - void ReassociateExpression(BinaryOperator *I); + Value *ReassociateExpression(BinaryOperator *I); void RewriteExprTree(BinaryOperator *I, SmallVectorImpl &Ops, unsigned Idx = 0); Value *OptimizeExpression(BinaryOperator *I, @@ -111,10 +111,13 @@ FunctionPass *llvm::createReassociatePass() { return new Reassociate(); } void Reassociate::RemoveDeadBinaryOp(Value *V) { Instruction *Op = dyn_cast(V); - if (!Op || !isa(Op) || !isa(Op) || !Op->use_empty()) + if (!Op || !isa(Op) || !Op->use_empty()) return; Value *LHS = Op->getOperand(0), *RHS = Op->getOperand(1); + + ValueRankMap.erase(Op); + Op->eraseFromParent(); RemoveDeadBinaryOp(LHS); RemoveDeadBinaryOp(RHS); } @@ -602,15 +605,57 @@ static Value *OptimizeAndOrXor(unsigned Opcode, /// is returned, otherwise the Ops list is mutated as necessary. Value *Reassociate::OptimizeAdd(Instruction *I, SmallVectorImpl &Ops) { + SmallPtrSet OperandsSeen; + +Restart: + OperandsSeen.clear(); + // Scan the operand lists looking for X and -X pairs. If we find any, we - // can simplify the expression. X+-X == 0. + // can simplify the expression. X+-X == 0. While we're at it, scan for any + // duplicates. We want to canonicalize Y+Y+Y+Z -> 3*Y+Z. for (unsigned i = 0, e = Ops.size(); i != e; ++i) { - assert(i < Ops.size()); + Value *TheOp = Ops[i].Op; + // Check to see if we've seen this operand before. If so, we factor all + // instances of the operand together. + if (!OperandsSeen.insert(TheOp)) { + // Rescan the list, removing all instances of this operand from the expr. + unsigned NumFound = 0; + for (unsigned j = 0, je = Ops.size(); j != je; ++j) { + if (Ops[j].Op != TheOp) continue; + ++NumFound; + Ops.erase(Ops.begin()+j); + --j; --je; + } + + /*DEBUG*/(errs() << "\nFACTORING [" << NumFound << "]: " << *TheOp << '\n'); + ++NumFactor; + + + // Insert a new multiply. + Value *Mul = ConstantInt::get(cast(I->getType()), NumFound); + Mul = BinaryOperator::CreateMul(TheOp, Mul, "factor", I); + + // Now that we have inserted a multiply, optimize it. This allows us to + // handle cases that require multiple factoring steps, such as this: + // (X*2) + (X*2) + (X*2) -> (X*2)*3 -> X*6 + Mul = ReassociateExpression(cast(Mul)); + + // If every add operand was a duplicate, return the multiply. + if (Ops.empty()) + return Mul; + + // Otherwise, we had some input that didn't have the dupe, such as + // "A + A + B" -> "A*2 + B". Add the new multiply to the list of + // things being added by this operation. + Ops.insert(Ops.begin(), ValueEntry(getRank(Mul), Mul)); + goto Restart; + } + // Check for X and -X in the operand list. - if (!BinaryOperator::isNeg(Ops[i].Op)) + if (!BinaryOperator::isNeg(TheOp)) continue; - Value *X = BinaryOperator::getNegArgument(Ops[i].Op); + Value *X = BinaryOperator::getNegArgument(TheOp); unsigned FoundX = FindInOperandList(Ops, i, X); if (FoundX == i) continue; @@ -639,7 +684,6 @@ Value *Reassociate::OptimizeAdd(Instruction *I, // Keep track of each multiply we see, to avoid triggering on (X*4)+(X*4) // where they are actually the same multiply. - SmallPtrSet Multiplies; unsigned MaxOcc = 0; Value *MaxOccVal = 0; for (unsigned i = 0, e = Ops.size(); i != e; ++i) { @@ -647,9 +691,6 @@ Value *Reassociate::OptimizeAdd(Instruction *I, if (BOp == 0 || BOp->getOpcode() != Instruction::Mul || !BOp->use_empty()) continue; - // If we've already seen this multiply, don't revisit it. - if (!Multiplies.insert(BOp)) continue; - // Compute all of the factors of this added value. SmallVector Factors; FindSingleUseMultiplyFactors(BOp, Factors); @@ -676,7 +717,7 @@ Value *Reassociate::OptimizeAdd(Instruction *I, // If any factor occurred more than one time, we can pull it out. if (MaxOcc > 1) { - DEBUG(errs() << "\nFACTORING [" << MaxOcc << "]: " << *MaxOccVal << "\n"); + DEBUG(errs() << "\nFACTORING [" << MaxOcc << "]: " << *MaxOccVal << '\n'); ++NumFactor; // Create a new instruction that uses the MaxOccVal twice. If we don't do @@ -698,13 +739,17 @@ Value *Reassociate::OptimizeAdd(Instruction *I, unsigned NumAddedValues = NewMulOps.size(); Value *V = EmitAddTreeOfValues(I, NewMulOps); - Value *V2 = BinaryOperator::CreateMul(V, MaxOccVal, "tmp", I); - // Now that we have inserted V and its sole use, optimize it. This allows - // us to handle cases that require multiple factoring steps, such as this: + // Now that we have inserted the add tree, optimize it. This allows us to + // handle cases that require multiple factoring steps, such as this: // A*A*B + A*A*C --> A*(A*B+A*C) --> A*(A*(B+C)) assert(NumAddedValues > 1 && "Each occurrence should contribute a value"); - ReassociateExpression(cast(V)); + V = ReassociateExpression(cast(V)); + + // Create the multiply. + Value *V2 = BinaryOperator::CreateMul(V, MaxOccVal, "tmp", I); + + // FIXME: Should rerun 'ReassociateExpression' on the mul too?? // If every add operand included the factor (e.g. "A*B + A*C"), then the // entire result expression is just the multiply "A*(B+C)". @@ -852,9 +897,10 @@ void Reassociate::ReassociateBB(BasicBlock *BB) { } } -void Reassociate::ReassociateExpression(BinaryOperator *I) { +Value *Reassociate::ReassociateExpression(BinaryOperator *I) { - // First, walk the expression tree, linearizing the tree, collecting + // First, walk the expression tree, linearizing the tree, collecting the + // operand information. SmallVector Ops; LinearizeExprTree(I, Ops); @@ -877,7 +923,7 @@ void Reassociate::ReassociateExpression(BinaryOperator *I) { I->replaceAllUsesWith(V); RemoveDeadBinaryOp(I); ++NumAnnihil; - return; + return V; } // We want to sink immediates as deeply as possible except in the case where @@ -899,11 +945,13 @@ void Reassociate::ReassociateExpression(BinaryOperator *I) { // eliminate it. I->replaceAllUsesWith(Ops[0].Op); RemoveDeadBinaryOp(I); - } else { - // Now that we ordered and optimized the expressions, splat them back into - // the expression tree, removing any unneeded nodes. - RewriteExprTree(I, Ops); + return Ops[0].Op; } + + // Now that we ordered and optimized the expressions, splat them back into + // the expression tree, removing any unneeded nodes. + RewriteExprTree(I, Ops); + return I; } diff --git a/test/Transforms/Reassociate/basictest.ll b/test/Transforms/Reassociate/basictest.ll index d96625419db..af08b583bb4 100644 --- a/test/Transforms/Reassociate/basictest.ll +++ b/test/Transforms/Reassociate/basictest.ll @@ -88,19 +88,19 @@ define void @test5() { } define i32 @test6() { - %tmp.0 = load i32* @a ; [#uses=2] - %tmp.1 = load i32* @b ; [#uses=2] + %tmp.0 = load i32* @a + %tmp.1 = load i32* @b ; (a+b) - %tmp.2 = add i32 %tmp.0, %tmp.1 ; [#uses=1] - %tmp.4 = load i32* @c ; [#uses=2] + %tmp.2 = add i32 %tmp.0, %tmp.1 + %tmp.4 = load i32* @c ; (a+b)+c - %tmp.5 = add i32 %tmp.2, %tmp.4 ; [#uses=1] + %tmp.5 = add i32 %tmp.2, %tmp.4 ; (a+c) - %tmp.8 = add i32 %tmp.0, %tmp.4 ; [#uses=1] + %tmp.8 = add i32 %tmp.0, %tmp.4 ; (a+c)+b - %tmp.11 = add i32 %tmp.8, %tmp.1 ; [#uses=1] + %tmp.11 = add i32 %tmp.8, %tmp.1 ; X ^ X = 0 - %RV = xor i32 %tmp.5, %tmp.11 ; [#uses=1] + %RV = xor i32 %tmp.5, %tmp.11 ret i32 %RV ; CHECK: @test6 ; CHECK: ret i32 0 @@ -108,6 +108,7 @@ define i32 @test6() { ; This should be one add and two multiplies. define i32 @test7(i32 %A, i32 %B, i32 %C) { + ; A*A*B + A*C*A %aa = mul i32 %A, %A %aab = mul i32 %aa, %B %ac = mul i32 %A, %C @@ -141,6 +142,27 @@ define i32 @test9(i32 %X) { %Z = add i32 %Y, %Y ret i32 %Z ; CHECK: @test9 -; CHECK-NEXT: %Z = mul i32 %X, 94 -; CHECK-NEXT: ret i32 %Z +; CHECK-NEXT: mul i32 %X, 94 +; CHECK-NEXT: ret i32 } + +define i32 @test10(i32 %X) { + %Y = add i32 %X ,%X + %Z = add i32 %Y, %X + ret i32 %Z +; CHECK: @test10 +; CHECK-NEXT: mul i32 %X, 3 +; CHECK-NEXT: ret i32 +} + +define i32 @test11(i32 %W) { + %X = mul i32 %W, 127 + %Y = add i32 %X ,%X + %Z = add i32 %Y, %X + ret i32 %Z +; CHECK: @test11 +; CHECK-NEXT: mul i32 %W, 381 +; CHECK-NEXT: ret i32 +} + +