From a0c9939873c404f272b3e0abb102c335146764fe Mon Sep 17 00:00:00 2001 From: Shuxin Yang Date: Thu, 14 Mar 2013 18:08:26 +0000 Subject: [PATCH] Perform factorization as a last resort of unsafe fadd/fsub simplification. Rules include: 1)1 x*y +/- x*z => x*(y +/- z) (the order of operands dosen't matter) 2) y/x +/- z/x => (y +/- z)/x The transformation is disabled if the new add/sub expr "y +/- z" is a denormal/naz/inifinity. rdar://12911472 git-svn-id: https://llvm.org/svn/llvm-project/llvm/trunk@177088 91177308-0d34-0410-b5e6-96231b3b80d8 --- .../InstCombine/InstCombineAddSub.cpp | 96 +++++++++++++++- test/Transforms/InstCombine/fast-math.ll | 105 ++++++++++++++++++ 2 files changed, 196 insertions(+), 5 deletions(-) diff --git a/lib/Transforms/InstCombine/InstCombineAddSub.cpp b/lib/Transforms/InstCombine/InstCombineAddSub.cpp index c6d60d6f008..3c5781ca73e 100644 --- a/lib/Transforms/InstCombine/InstCombineAddSub.cpp +++ b/lib/Transforms/InstCombine/InstCombineAddSub.cpp @@ -150,7 +150,9 @@ namespace { typedef SmallVector AddendVect; Value *simplifyFAdd(AddendVect& V, unsigned InstrQuota); - + + Value *performFactorization(Instruction *I); + /// Convert given addend to a Value Value *createAddendVal(const FAddend &A, bool& NeedNeg); @@ -159,6 +161,7 @@ namespace { Value *createFSub(Value *Opnd0, Value *Opnd1); Value *createFAdd(Value *Opnd0, Value *Opnd1); Value *createFMul(Value *Opnd0, Value *Opnd1); + Value *createFDiv(Value *Opnd0, Value *Opnd1); Value *createFNeg(Value *V); Value *createNaryFAdd(const AddendVect& Opnds, unsigned InstrQuota); void createInstPostProc(Instruction *NewInst); @@ -388,6 +391,78 @@ unsigned FAddend::drillAddendDownOneStep return BreakNum; } +// Try to perform following optimization on the input instruction I. Return the +// simplified expression if was successful; otherwise, return 0. +// +// Instruction "I" is Simplified into +// ------------------------------------------------------- +// (x * y) +/- (x * z) x * (y +/- z) +// (y / x) +/- (z / x) (y +/- z) / x +// +Value *FAddCombine::performFactorization(Instruction *I) { + assert((I->getOpcode() == Instruction::FAdd || + I->getOpcode() == Instruction::FSub) && "Expect add/sub"); + + Instruction *I0 = dyn_cast(I->getOperand(0)); + Instruction *I1 = dyn_cast(I->getOperand(1)); + + if (!I0 || !I1 || I0->getOpcode() != I1->getOpcode()) + return 0; + + bool isMpy = false; + if (I0->getOpcode() == Instruction::FMul) + isMpy = true; + else if (I0->getOpcode() != Instruction::FDiv) + return 0; + + Value *Opnd0_0 = I0->getOperand(0); + Value *Opnd0_1 = I0->getOperand(1); + Value *Opnd1_0 = I1->getOperand(0); + Value *Opnd1_1 = I1->getOperand(1); + + // Input Instr I Factor AddSub0 AddSub1 + // ---------------------------------------------- + // (x*y) +/- (x*z) x y z + // (y/x) +/- (z/x) x y z + // + Value *Factor = 0; + Value *AddSub0 = 0, *AddSub1 = 0; + + if (isMpy) { + if (Opnd0_0 == Opnd1_0 || Opnd0_0 == Opnd1_1) + Factor = Opnd0_0; + else if (Opnd0_1 == Opnd1_0 || Opnd0_1 == Opnd1_1) + Factor = Opnd0_1; + + if (Factor) { + AddSub0 = (Factor == Opnd0_0) ? Opnd0_1 : Opnd0_0; + AddSub1 = (Factor == Opnd1_0) ? Opnd1_1 : Opnd1_0; + } + } else if (Opnd0_1 == Opnd1_1) { + Factor = Opnd0_1; + AddSub0 = Opnd0_0; + AddSub1 = Opnd1_0; + } + + if (!Factor) + return 0; + + // Create expression "NewAddSub = AddSub0 +/- AddsSub1" + Value *NewAddSub = (I->getOpcode() == Instruction::FAdd) ? + createFAdd(AddSub0, AddSub1) : + createFSub(AddSub0, AddSub1); + if (ConstantFP *CFP = dyn_cast(NewAddSub)) { + const APFloat &F = CFP->getValueAPF(); + if (!F.isNormal() || F.isDenormal()) + return 0; + } + + if (isMpy) + return createFMul(Factor, NewAddSub); + + return createFDiv(NewAddSub, Factor); +} + Value *FAddCombine::simplify(Instruction *I) { assert(I->hasUnsafeAlgebra() && "Should be in unsafe mode"); @@ -471,7 +546,8 @@ Value *FAddCombine::simplify(Instruction *I) { return R; } - return 0; + // step 6: Try factorization as the last resort, + return performFactorization(I); } Value *FAddCombine::simplifyFAdd(AddendVect& Addends, unsigned InstrQuota) { @@ -627,7 +703,8 @@ Value *FAddCombine::createNaryFAdd Value *FAddCombine::createFSub (Value *Opnd0, Value *Opnd1) { Value *V = Builder->CreateFSub(Opnd0, Opnd1); - createInstPostProc(cast(V)); + if (Instruction *I = dyn_cast(V)) + createInstPostProc(I); return V; } @@ -639,13 +716,22 @@ Value *FAddCombine::createFNeg(Value *V) { Value *FAddCombine::createFAdd (Value *Opnd0, Value *Opnd1) { Value *V = Builder->CreateFAdd(Opnd0, Opnd1); - createInstPostProc(cast(V)); + if (Instruction *I = dyn_cast(V)) + createInstPostProc(I); return V; } Value *FAddCombine::createFMul(Value *Opnd0, Value *Opnd1) { Value *V = Builder->CreateFMul(Opnd0, Opnd1); - createInstPostProc(cast(V)); + if (Instruction *I = dyn_cast(V)) + createInstPostProc(I); + return V; +} + +Value *FAddCombine::createFDiv(Value *Opnd0, Value *Opnd1) { + Value *V = Builder->CreateFDiv(Opnd0, Opnd1); + if (Instruction *I = dyn_cast(V)) + createInstPostProc(I); return V; } diff --git a/test/Transforms/InstCombine/fast-math.ll b/test/Transforms/InstCombine/fast-math.ll index 3e32a2e4dd4..47f1ec48046 100644 --- a/test/Transforms/InstCombine/fast-math.ll +++ b/test/Transforms/InstCombine/fast-math.ll @@ -350,3 +350,108 @@ define float @fdiv9(float %x) { ; CHECK: @fdiv9 ; CHECK: fmul fast float %x, 5.000000e+00 } + +; ========================================================================= +; +; Testing-cases about factorization +; +; ========================================================================= +; x*z + y*z => (x+y) * z +define float @fact_mul1(float %x, float %y, float %z) { + %t1 = fmul fast float %x, %z + %t2 = fmul fast float %y, %z + %t3 = fadd fast float %t1, %t2 + ret float %t3 +; CHECK: @fact_mul1 +; CHECK: fmul fast float %1, %z +} + +; z*x + y*z => (x+y) * z +define float @fact_mul2(float %x, float %y, float %z) { + %t1 = fmul fast float %z, %x + %t2 = fmul fast float %y, %z + %t3 = fsub fast float %t1, %t2 + ret float %t3 +; CHECK: @fact_mul2 +; CHECK: fmul fast float %1, %z +} + +; z*x - z*y => (x-y) * z +define float @fact_mul3(float %x, float %y, float %z) { + %t2 = fmul fast float %z, %y + %t1 = fmul fast float %z, %x + %t3 = fsub fast float %t1, %t2 + ret float %t3 +; CHECK: @fact_mul3 +; CHECK: fmul fast float %1, %z +} + +; x*z - z*y => (x-y) * z +define float @fact_mul4(float %x, float %y, float %z) { + %t1 = fmul fast float %x, %z + %t2 = fmul fast float %z, %y + %t3 = fsub fast float %t1, %t2 + ret float %t3 +; CHECK: @fact_mul4 +; CHECK: fmul fast float %1, %z +} + +; x/y + x/z, no xform +define float @fact_div1(float %x, float %y, float %z) { + %t1 = fdiv fast float %x, %y + %t2 = fdiv fast float %x, %z + %t3 = fadd fast float %t1, %t2 + ret float %t3 +; CHECK: fact_div1 +; CHECK: fadd fast float %t1, %t2 +} + +; x/y + z/x; no xform +define float @fact_div2(float %x, float %y, float %z) { + %t1 = fdiv fast float %x, %y + %t2 = fdiv fast float %z, %x + %t3 = fadd fast float %t1, %t2 + ret float %t3 +; CHECK: fact_div2 +; CHECK: fadd fast float %t1, %t2 +} + +; y/x + z/x => (y+z)/x +define float @fact_div3(float %x, float %y, float %z) { + %t1 = fdiv fast float %y, %x + %t2 = fdiv fast float %z, %x + %t3 = fadd fast float %t1, %t2 + ret float %t3 +; CHECK: fact_div3 +; CHECK: fdiv fast float %1, %x +} + +; y/x - z/x => (y-z)/x +define float @fact_div4(float %x, float %y, float %z) { + %t1 = fdiv fast float %y, %x + %t2 = fdiv fast float %z, %x + %t3 = fsub fast float %t1, %t2 + ret float %t3 +; CHECK: fact_div4 +; CHECK: fdiv fast float %1, %x +} + +; y/x - z/x => (y-z)/x is disabled if y-z is denormal. +define float @fact_div5(float %x) { + %t1 = fdiv fast float 0x3810000000000000, %x + %t2 = fdiv fast float 0x3800000000000000, %x + %t3 = fadd fast float %t1, %t2 + ret float %t3 +; CHECK: fact_div5 +; CHECK: fdiv fast float 0x3818000000000000, %x +} + +; y/x - z/x => (y-z)/x is disabled if y-z is denormal. +define float @fact_div6(float %x) { + %t1 = fdiv fast float 0x3810000000000000, %x + %t2 = fdiv fast float 0x3800000000000000, %x + %t3 = fsub fast float %t1, %t2 + ret float %t3 +; CHECK: fact_div6 +; CHECK: %t3 = fsub fast float %t1, %t2 +}