fold: sqrt(x * x * y) -> fabs(x) * sqrt(y)

If a square root call has an FP multiplication argument that can be reassociated,
then we can hoist a repeated factor out of the square root call and into a fabs().

In the simplest case, this:

   y = sqrt(x * x);

becomes this:

   y = fabs(x);

This patch relies on an earlier optimization in instcombine or reassociate to put the
multiplication tree into a canonical form, so we don't have to search over
every permutation of the multiplication tree.

Because there are no IR-level FastMathFlags for intrinsics (PR21290), we have to
use function-level attributes to do this optimization. This needs to be fixed
for both the intrinsics and in the backend.

Differential Revision: http://reviews.llvm.org/D5787



git-svn-id: https://llvm.org/svn/llvm-project/llvm/trunk@219944 91177308-0d34-0410-b5e6-96231b3b80d8
This commit is contained in:
Sanjay Patel
2014-10-16 18:48:17 +00:00
parent c40dab2069
commit d8214db086
3 changed files with 258 additions and 1 deletions

View File

@ -27,12 +27,14 @@
#include "llvm/IR/Intrinsics.h"
#include "llvm/IR/LLVMContext.h"
#include "llvm/IR/Module.h"
#include "llvm/IR/PatternMatch.h"
#include "llvm/Support/Allocator.h"
#include "llvm/Support/CommandLine.h"
#include "llvm/Target/TargetLibraryInfo.h"
#include "llvm/Transforms/Utils/BuildLibCalls.h"
using namespace llvm;
using namespace PatternMatch;
static cl::opt<bool>
ColdErrorCalls("error-reporting-is-cold", cl::init(true), cl::Hidden,
@ -1254,6 +1256,85 @@ Value *LibCallSimplifier::optimizeFabs(CallInst *CI, IRBuilder<> &B) {
return Ret;
}
Value *LibCallSimplifier::optimizeSqrt(CallInst *CI, IRBuilder<> &B) {
Function *Callee = CI->getCalledFunction();
Value *Ret = nullptr;
if (UnsafeFPShrink && Callee->getName() == "sqrt" &&
TLI->has(LibFunc::sqrtf)) {
Ret = optimizeUnaryDoubleFP(CI, B, true);
}
// FIXME: For finer-grain optimization, we need intrinsics to have the same
// fast-math flag decorations that are applied to FP instructions. For now,
// we have to rely on the function-level unsafe-fp-math attribute to do this
// optimization because there's no other way to express that the sqrt can be
// reassociated.
Function *F = CI->getParent()->getParent();
if (F->hasFnAttribute("unsafe-fp-math")) {
// Check for unsafe-fp-math = true.
Attribute Attr = F->getFnAttribute("unsafe-fp-math");
if (Attr.getValueAsString() != "true")
return Ret;
}
Value *Op = CI->getArgOperand(0);
if (Instruction *I = dyn_cast<Instruction>(Op)) {
if (I->getOpcode() == Instruction::FMul && I->hasUnsafeAlgebra()) {
// We're looking for a repeated factor in a multiplication tree,
// so we can do this fold: sqrt(x * x) -> fabs(x);
// or this fold: sqrt(x * x * y) -> fabs(x) * sqrt(y).
Value *Op0 = I->getOperand(0);
Value *Op1 = I->getOperand(1);
Value *RepeatOp = nullptr;
Value *OtherOp = nullptr;
if (Op0 == Op1) {
// Simple match: the operands of the multiply are identical.
RepeatOp = Op0;
} else {
// Look for a more complicated pattern: one of the operands is itself
// a multiply, so search for a common factor in that multiply.
// Note: We don't bother looking any deeper than this first level or for
// variations of this pattern because instcombine's visitFMUL and/or the
// reassociation pass should give us this form.
Value *OtherMul0, *OtherMul1;
if (match(Op0, m_FMul(m_Value(OtherMul0), m_Value(OtherMul1)))) {
// Pattern: sqrt((x * y) * z)
if (OtherMul0 == OtherMul1) {
// Matched: sqrt((x * x) * z)
RepeatOp = OtherMul0;
OtherOp = Op1;
}
}
}
if (RepeatOp) {
// Fast math flags for any created instructions should match the sqrt
// and multiply.
// FIXME: We're not checking the sqrt because it doesn't have
// fast-math-flags (see earlier comment).
IRBuilder<true, ConstantFolder,
IRBuilderDefaultInserter<true> >::FastMathFlagGuard Guard(B);
B.SetFastMathFlags(I->getFastMathFlags());
// If we found a repeated factor, hoist it out of the square root and
// replace it with the fabs of that factor.
Module *M = Callee->getParent();
Type *ArgType = Op->getType();
Value *Fabs = Intrinsic::getDeclaration(M, Intrinsic::fabs, ArgType);
Value *FabsCall = B.CreateCall(Fabs, RepeatOp, "fabs");
if (OtherOp) {
// If we found a non-repeated factor, we still need to get its square
// root. We then multiply that by the value that was simplified out
// of the square root calculation.
Value *Sqrt = Intrinsic::getDeclaration(M, Intrinsic::sqrt, ArgType);
Value *SqrtCall = B.CreateCall(Sqrt, OtherOp, "sqrt");
return B.CreateFMul(FabsCall, SqrtCall);
}
return FabsCall;
}
}
}
return Ret;
}
static bool isTrigLibCall(CallInst *CI);
static void insertSinCosCall(IRBuilder<> &B, Function *OrigCallee, Value *Arg,
bool UseFloat, Value *&Sin, Value *&Cos,
@ -1919,6 +2000,8 @@ Value *LibCallSimplifier::optimizeCall(CallInst *CI) {
return optimizeExp2(CI, Builder);
case Intrinsic::fabs:
return optimizeFabs(CI, Builder);
case Intrinsic::sqrt:
return optimizeSqrt(CI, Builder);
default:
return nullptr;
}
@ -1995,6 +2078,10 @@ Value *LibCallSimplifier::optimizeCall(CallInst *CI) {
case LibFunc::fabs:
case LibFunc::fabsl:
return optimizeFabs(CI, Builder);
case LibFunc::sqrtf:
case LibFunc::sqrt:
case LibFunc::sqrtl:
return optimizeSqrt(CI, Builder);
case LibFunc::ffs:
case LibFunc::ffsl:
case LibFunc::ffsll:
@ -2055,7 +2142,6 @@ Value *LibCallSimplifier::optimizeCall(CallInst *CI) {
case LibFunc::logb:
case LibFunc::sin:
case LibFunc::sinh:
case LibFunc::sqrt:
case LibFunc::tan:
case LibFunc::tanh:
if (UnsafeFPShrink && hasFloatVersion(FuncName))