From 6229219f7e7125b73bbc18b164a482e4178a8bed Mon Sep 17 00:00:00 2001 From: Chad Rosier Date: Mon, 23 Feb 2015 19:15:16 +0000 Subject: [PATCH] Prevent hoisting fmul from THEN/ELSE to IF if there is fmsub/fmadd opportunity. This patch adds the isProfitableToHoist API. For AArch64, we want to prevent a fmul from being hoisted in cases where it is more profitable to form a fmsub/fmadd. Phabricator Review: http://reviews.llvm.org/D7299 Patch by Lawrence Hu git-svn-id: https://llvm.org/svn/llvm-project/llvm/trunk@230241 91177308-0d34-0410-b5e6-96231b3b80d8 --- include/llvm/Analysis/TargetTransformInfo.h | 8 +++ .../llvm/Analysis/TargetTransformInfoImpl.h | 2 + include/llvm/CodeGen/BasicTTIImpl.h | 4 ++ include/llvm/Target/TargetLowering.h | 2 + lib/Analysis/TargetTransformInfo.cpp | 4 ++ lib/Target/AArch64/AArch64ISelLowering.cpp | 28 ++++++++ lib/Target/AArch64/AArch64ISelLowering.h | 3 + lib/Transforms/Utils/SimplifyCFG.cpp | 8 ++- .../SimplifyCFG/AArch64/lit.local.cfg | 5 ++ .../SimplifyCFG/AArch64/prefer-fma.ll | 72 +++++++++++++++++++ 10 files changed, 134 insertions(+), 2 deletions(-) create mode 100644 test/Transforms/SimplifyCFG/AArch64/lit.local.cfg create mode 100644 test/Transforms/SimplifyCFG/AArch64/prefer-fma.ll diff --git a/include/llvm/Analysis/TargetTransformInfo.h b/include/llvm/Analysis/TargetTransformInfo.h index 26ceac189a1..49981416604 100644 --- a/include/llvm/Analysis/TargetTransformInfo.h +++ b/include/llvm/Analysis/TargetTransformInfo.h @@ -313,6 +313,10 @@ public: /// by referencing its sub-register AX. bool isTruncateFree(Type *Ty1, Type *Ty2) const; + /// \brief Return true if it is profitable to hoist an instruction from the + /// then/else blocks to before the if. + bool isProfitableToHoist(Instruction *I) const; + /// \brief Return true if this type is legal. 
bool isTypeLegal(Type *Ty) const; @@ -521,6 +525,7 @@ public: int64_t BaseOffset, bool HasBaseReg, int64_t Scale) = 0; virtual bool isTruncateFree(Type *Ty1, Type *Ty2) = 0; + virtual bool isProfitableToHoist(Instruction *I) = 0; virtual bool isTypeLegal(Type *Ty) = 0; virtual unsigned getJumpBufAlignment() = 0; virtual unsigned getJumpBufSize() = 0; @@ -633,6 +638,9 @@ public: bool isTruncateFree(Type *Ty1, Type *Ty2) override { return Impl.isTruncateFree(Ty1, Ty2); } + bool isProfitableToHoist(Instruction *I) override { + return Impl.isProfitableToHoist(I); + } bool isTypeLegal(Type *Ty) override { return Impl.isTypeLegal(Ty); } unsigned getJumpBufAlignment() override { return Impl.getJumpBufAlignment(); } unsigned getJumpBufSize() override { return Impl.getJumpBufSize(); } diff --git a/include/llvm/Analysis/TargetTransformInfoImpl.h b/include/llvm/Analysis/TargetTransformInfoImpl.h index 0254880b4e6..3e02c0ce3ca 100644 --- a/include/llvm/Analysis/TargetTransformInfoImpl.h +++ b/include/llvm/Analysis/TargetTransformInfoImpl.h @@ -225,6 +225,8 @@ public: bool isTruncateFree(Type *Ty1, Type *Ty2) { return false; } + bool isProfitableToHoist(Instruction *I) { return true; } + bool isTypeLegal(Type *Ty) { return false; } unsigned getJumpBufAlignment() { return 0; } diff --git a/include/llvm/CodeGen/BasicTTIImpl.h b/include/llvm/CodeGen/BasicTTIImpl.h index 25a74b331de..ff85b064bc9 100644 --- a/include/llvm/CodeGen/BasicTTIImpl.h +++ b/include/llvm/CodeGen/BasicTTIImpl.h @@ -145,6 +145,10 @@ public: return getTLI()->isTruncateFree(Ty1, Ty2); } + bool isProfitableToHoist(Instruction *I) { + return getTLI()->isProfitableToHoist(I); + } + bool isTypeLegal(Type *Ty) { EVT VT = getTLI()->getValueType(Ty); return getTLI()->isTypeLegal(VT); diff --git a/include/llvm/Target/TargetLowering.h b/include/llvm/Target/TargetLowering.h index d320bf1c30a..cd499ba5cb0 100644 --- a/include/llvm/Target/TargetLowering.h +++ b/include/llvm/Target/TargetLowering.h @@ -1456,6 +1456,8 @@ 
public: return false; } + virtual bool isProfitableToHoist(Instruction *I) const { return true; } + /// Return true if any actual instruction that defines a value of type Ty1 /// implicitly zero-extends the value to Ty2 in the result register. /// diff --git a/lib/Analysis/TargetTransformInfo.cpp b/lib/Analysis/TargetTransformInfo.cpp index b5440e2a2c3..7ff29b028ae 100644 --- a/lib/Analysis/TargetTransformInfo.cpp +++ b/lib/Analysis/TargetTransformInfo.cpp @@ -123,6 +123,10 @@ bool TargetTransformInfo::isTruncateFree(Type *Ty1, Type *Ty2) const { return TTIImpl->isTruncateFree(Ty1, Ty2); } +bool TargetTransformInfo::isProfitableToHoist(Instruction *I) const { + return TTIImpl->isProfitableToHoist(I); +} + bool TargetTransformInfo::isTypeLegal(Type *Ty) const { return TTIImpl->isTypeLegal(Ty); } diff --git a/lib/Target/AArch64/AArch64ISelLowering.cpp b/lib/Target/AArch64/AArch64ISelLowering.cpp index 332c8796c0a..fb31d7d3376 100644 --- a/lib/Target/AArch64/AArch64ISelLowering.cpp +++ b/lib/Target/AArch64/AArch64ISelLowering.cpp @@ -6533,6 +6533,34 @@ bool AArch64TargetLowering::isTruncateFree(EVT VT1, EVT VT2) const { return NumBits1 > NumBits2; } +/// Check if it is profitable to hoist instruction in then/else to if. +/// Not profitable if I and its user can form an FMA instruction +/// because we prefer FMSUB/FMADD. 
+bool AArch64TargetLowering::isProfitableToHoist(Instruction *I) const { + if (I->getOpcode() != Instruction::FMul) + return true; + + if (I->getNumUses() != 1) + return true; + + Instruction *User = I->user_back(); + + if (User && + !(User->getOpcode() == Instruction::FSub || + User->getOpcode() == Instruction::FAdd)) + return true; + + const TargetOptions &Options = getTargetMachine().Options; + EVT VT = getValueType(User->getOperand(0)->getType()); + + if (isFMAFasterThanFMulAndFAdd(VT) && + isOperationLegalOrCustom(ISD::FMA, VT) && + (Options.AllowFPOpFusion == FPOpFusion::Fast || Options.UnsafeFPMath)) + return false; + + return true; +} + // All 32-bit GPR operations implicitly zero the high-half of the corresponding // 64-bit GPR. bool AArch64TargetLowering::isZExtFree(Type *Ty1, Type *Ty2) const { diff --git a/lib/Target/AArch64/AArch64ISelLowering.h b/lib/Target/AArch64/AArch64ISelLowering.h index 6cbc425e71f..db15538e43b 100644 --- a/lib/Target/AArch64/AArch64ISelLowering.h +++ b/lib/Target/AArch64/AArch64ISelLowering.h @@ -18,6 +18,7 @@ #include "llvm/CodeGen/CallingConvLower.h" #include "llvm/CodeGen/SelectionDAG.h" #include "llvm/IR/CallingConv.h" +#include "llvm/IR/Instruction.h" #include "llvm/Target/TargetLowering.h" namespace llvm { @@ -286,6 +287,8 @@ public: bool isTruncateFree(Type *Ty1, Type *Ty2) const override; bool isTruncateFree(EVT VT1, EVT VT2) const override; + bool isProfitableToHoist(Instruction *I) const override; + bool isZExtFree(Type *Ty1, Type *Ty2) const override; bool isZExtFree(EVT VT1, EVT VT2) const override; bool isZExtFree(SDValue Val, EVT VT2) const override; diff --git a/lib/Transforms/Utils/SimplifyCFG.cpp b/lib/Transforms/Utils/SimplifyCFG.cpp index 9cbd05f897b..3248a83636c 100644 --- a/lib/Transforms/Utils/SimplifyCFG.cpp +++ b/lib/Transforms/Utils/SimplifyCFG.cpp @@ -1053,7 +1053,8 @@ static bool passingValueIsAlwaysUndefined(Value *V, Instruction *I); /// HoistThenElseCodeToIf - Given a conditional branch that goes 
to BB1 and /// BB2, hoist any common code in the two blocks up into the branch block. The /// caller of this function guarantees that BI's block dominates BB1 and BB2. -static bool HoistThenElseCodeToIf(BranchInst *BI, const DataLayout *DL) { +static bool HoistThenElseCodeToIf(BranchInst *BI, const DataLayout *DL, + const TargetTransformInfo &TTI) { // This does very trivial matching, with limited scanning, to find identical // instructions in the two blocks. In particular, we don't want to get into // O(M*N) situations here where M and N are the sizes of BB1 and BB2. As @@ -1088,6 +1089,9 @@ static bool HoistThenElseCodeToIf(BranchInst *BI, const DataLayout *DL) { if (isa<TerminatorInst>(I1)) goto HoistTerminator; + if (!TTI.isProfitableToHoist(I1) || !TTI.isProfitableToHoist(I2)) + return Changed; + // For a normal instruction, we just move one to right before the branch, // then replace all uses of the other with the first. Finally, we remove // the now redundant second instruction. @@ -4442,7 +4446,7 @@ bool SimplifyCFGOpt::SimplifyCondBranch(BranchInst *BI, IRBuilder<> &Builder) { // can hoist it up to the branching block. 
if (BI->getSuccessor(0)->getSinglePredecessor()) { if (BI->getSuccessor(1)->getSinglePredecessor()) { - if (HoistThenElseCodeToIf(BI, DL)) + if (HoistThenElseCodeToIf(BI, DL, TTI)) return SimplifyCFG(BB, TTI, BonusInstThreshold, DL, AC) | true; } else { // If Successor #1 has multiple preds, we may be able to conditionally diff --git a/test/Transforms/SimplifyCFG/AArch64/lit.local.cfg b/test/Transforms/SimplifyCFG/AArch64/lit.local.cfg new file mode 100644 index 00000000000..6642d287068 --- /dev/null +++ b/test/Transforms/SimplifyCFG/AArch64/lit.local.cfg @@ -0,0 +1,5 @@ +config.suffixes = ['.ll'] + +targets = set(config.root.targets_to_build.split()) +if not 'AArch64' in targets: + config.unsupported = True diff --git a/test/Transforms/SimplifyCFG/AArch64/prefer-fma.ll b/test/Transforms/SimplifyCFG/AArch64/prefer-fma.ll new file mode 100644 index 00000000000..076cb5867af --- /dev/null +++ b/test/Transforms/SimplifyCFG/AArch64/prefer-fma.ll @@ -0,0 +1,72 @@ +; RUN: opt < %s -mtriple=aarch64-linux-gnu -simplifycfg -enable-unsafe-fp-math -S >%t +; RUN: FileCheck %s < %t +; ModuleID = 't.cc' + +; Function Attrs: nounwind +define double @_Z3fooRdS_S_S_(double* dereferenceable(8) %x, double* dereferenceable(8) %y, double* dereferenceable(8) %a) #0 { +entry: + %0 = load double* %y, align 8 + %cmp = fcmp oeq double %0, 0.000000e+00 + %1 = load double* %x, align 8 + br i1 %cmp, label %if.then, label %if.else + +; fadd (const, (fmul x, y)) +if.then: ; preds = %entry +; CHECK-LABEL: if.then: +; CHECK: %3 = fmul fast double %1, %2 +; CHECK-NEXT: %mul = fadd fast double 1.000000e+00, %3 + %2 = load double* %a, align 8 + %3 = fmul fast double %1, %2 + %mul = fadd fast double 1.000000e+00, %3 + store double %mul, double* %y, align 8 + br label %if.end + +; fsub ((fmul x, y), z) +if.else: ; preds = %entry +; CHECK-LABEL: if.else: +; CHECK: %mul1 = fmul fast double %1, %2 +; CHECK-NEXT: %sub1 = fsub fast double %mul1, %0 + %4 = load double* %a, align 8 + %mul1 = fmul fast double 
%1, %4 + %sub1 = fsub fast double %mul1, %0 + store double %sub1, double* %y, align 8 + br label %if.end + +if.end: ; preds = %if.else, %if.then + %5 = load double* %y, align 8 + %cmp2 = fcmp oeq double %5, 2.000000e+00 + %6 = load double* %x, align 8 + br i1 %cmp2, label %if.then2, label %if.else2 + +; fsub (x, (fmul y, z)) +if.then2: ; preds = %entry +; CHECK-LABEL: if.then2: +; CHECK: %7 = fmul fast double %5, 3.000000e+00 +; CHECK-NEXT: %mul2 = fsub fast double %6, %7 + %7 = load double* %a, align 8 + %8 = fmul fast double %6, 3.0000000e+00 + %mul2 = fsub fast double %7, %8 + store double %mul2, double* %y, align 8 + br label %if.end2 + +; fsub (fneg((fmul x, y)), const) +if.else2: ; preds = %entry +; CHECK-LABEL: if.else2: +; CHECK: %mul3 = fmul fast double %5, 3.000000e+00 +; CHECK-NEXT: %neg = fsub fast double 0.000000e+00, %mul3 +; CHECK-NEXT: %sub2 = fsub fast double %neg, 3.000000e+00 + %mul3 = fmul fast double %6, 3.0000000e+00 + %neg = fsub fast double 0.0000000e+00, %mul3 + %sub2 = fsub fast double %neg, 3.0000000e+00 + store double %sub2, double* %y, align 8 + br label %if.end2 + +if.end2: ; preds = %if.else, %if.then + %9 = load double* %x, align 8 + %10 = load double* %y, align 8 + %add = fadd fast double %9, %10 + %11 = load double* %a, align 8 + %add2 = fadd fast double %add, %11 + ret double %add2 +} +