diff --git a/include/llvm/CodeGen/ISDOpcodes.h b/include/llvm/CodeGen/ISDOpcodes.h index 2d5d6c3393e..2d1c8cd6fdd 100644 --- a/include/llvm/CodeGen/ISDOpcodes.h +++ b/include/llvm/CodeGen/ISDOpcodes.h @@ -229,7 +229,14 @@ namespace ISD { SMULO, UMULO, /// Simple binary floating point operators. - FADD, FSUB, FMUL, FMA, FDIV, FREM, + FADD, FSUB, FMUL, FDIV, FREM, + + /// FMA - Perform a * b + c with no intermediate rounding step. + FMA, + + /// FMAD - Perform a * b + c, while getting the same result as the + /// separately rounded operations. + FMAD, /// FCOPYSIGN(X, Y) - Return the value of X with the sign of Y. NOTE: This /// DAG node does not require that X and Y have the same type, just that the diff --git a/include/llvm/Target/TargetSelectionDAG.td b/include/llvm/Target/TargetSelectionDAG.td index 744a7b0c97a..2ecd900d34b 100644 --- a/include/llvm/Target/TargetSelectionDAG.td +++ b/include/llvm/Target/TargetSelectionDAG.td @@ -381,6 +381,7 @@ def fmul : SDNode<"ISD::FMUL" , SDTFPBinOp, [SDNPCommutative]>; def fdiv : SDNode<"ISD::FDIV" , SDTFPBinOp>; def frem : SDNode<"ISD::FREM" , SDTFPBinOp>; def fma : SDNode<"ISD::FMA" , SDTFPTernaryOp>; +def fmad : SDNode<"ISD::FMAD" , SDTFPTernaryOp>; def fabs : SDNode<"ISD::FABS" , SDTFPUnaryOp>; def fminnum : SDNode<"ISD::FMINNUM" , SDTFPBinOp>; def fmaxnum : SDNode<"ISD::FMAXNUM" , SDTFPBinOp>; diff --git a/lib/CodeGen/SelectionDAG/DAGCombiner.cpp b/lib/CodeGen/SelectionDAG/DAGCombiner.cpp index e8d1acf60ab..53867b57bc3 100644 --- a/lib/CodeGen/SelectionDAG/DAGCombiner.cpp +++ b/lib/CodeGen/SelectionDAG/DAGCombiner.cpp @@ -6938,6 +6938,133 @@ ConstantFoldBITCASTofBUILD_VECTOR(SDNode *BV, EVT DstEltVT) { return DAG.getNode(ISD::BUILD_VECTOR, SDLoc(BV), VT, Ops); } +// Attempt different variants of (fadd (fmul a, b), c) -> fma or fmad +static SDValue performFaddFmulCombines(unsigned FusedOpcode, + bool Aggressive, + SDNode *N, + const TargetLowering &TLI, + SelectionDAG &DAG) { + SDValue N0 = N->getOperand(0); + SDValue N1 = N->getOperand(1); + EVT VT = N->getValueType(0); + + // fold (fadd (fmul x, y), z) -> (fma x, y, z) + if (N0.getOpcode() == ISD::FMUL && + (Aggressive || N0->hasOneUse())) { + return DAG.getNode(FusedOpcode, SDLoc(N), VT, + N0.getOperand(0), N0.getOperand(1), N1); + } + + // fold (fadd x, (fmul y, z)) -> (fma y, z, x) + // Note: Commutes FADD operands. + if (N1.getOpcode() == ISD::FMUL && + (Aggressive || N1->hasOneUse())) { + return DAG.getNode(FusedOpcode, SDLoc(N), VT, + N1.getOperand(0), N1.getOperand(1), N0); + } + + // More folding opportunities when target permits. + if (Aggressive) { + // fold (fadd (fma x, y, (fmul u, v)), z) -> (fma x, y (fma u, v, z)) + if (N0.getOpcode() == ISD::FMA && + N0.getOperand(2).getOpcode() == ISD::FMUL) { + return DAG.getNode(FusedOpcode, SDLoc(N), VT, + N0.getOperand(0), N0.getOperand(1), + DAG.getNode(FusedOpcode, SDLoc(N), VT, + N0.getOperand(2).getOperand(0), + N0.getOperand(2).getOperand(1), + N1)); + } + + // fold (fadd x, (fma y, z, (fmul u, v)) -> (fma y, z (fma u, v, x)) + if (N1->getOpcode() == ISD::FMA && + N1.getOperand(2).getOpcode() == ISD::FMUL) { + return DAG.getNode(FusedOpcode, SDLoc(N), VT, + N1.getOperand(0), N1.getOperand(1), + DAG.getNode(FusedOpcode, SDLoc(N), VT, + N1.getOperand(2).getOperand(0), + N1.getOperand(2).getOperand(1), + N0)); + } + } + + return SDValue(); +} + +static SDValue performFsubFmulCombines(unsigned FusedOpcode, + bool Aggressive, + SDNode *N, + const TargetLowering &TLI, + SelectionDAG &DAG) { + SDValue N0 = N->getOperand(0); + SDValue N1 = N->getOperand(1); + EVT VT = N->getValueType(0); + + SDLoc SL(N); + + // fold (fsub (fmul x, y), z) -> (fma x, y, (fneg z)) + if (N0.getOpcode() == ISD::FMUL && + (Aggressive || N0->hasOneUse())) { + return DAG.getNode(FusedOpcode, SL, VT, + N0.getOperand(0), N0.getOperand(1), + DAG.getNode(ISD::FNEG, SL, VT, N1)); + } + + // fold (fsub x, (fmul y, z)) -> (fma (fneg y), z, x) + // Note: Commutes FSUB operands. + if (N1.getOpcode() == ISD::FMUL && + (Aggressive || N1->hasOneUse())) + return DAG.getNode(FusedOpcode, SL, VT, + DAG.getNode(ISD::FNEG, SL, VT, + N1.getOperand(0)), + N1.getOperand(1), N0); + + // fold (fsub (fneg (fmul, x, y)), z) -> (fma (fneg x), y, (fneg z)) + if (N0.getOpcode() == ISD::FNEG && + N0.getOperand(0).getOpcode() == ISD::FMUL && + (Aggressive || (N0->hasOneUse() && N0.getOperand(0).hasOneUse()))) { + SDValue N00 = N0.getOperand(0).getOperand(0); + SDValue N01 = N0.getOperand(0).getOperand(1); + return DAG.getNode(FusedOpcode, SL, VT, + DAG.getNode(ISD::FNEG, SL, VT, N00), N01, + DAG.getNode(ISD::FNEG, SL, VT, N1)); + } + + // More folding opportunities when target permits. + if (Aggressive) { + // fold (fsub (fma x, y, (fmul u, v)), z) + // -> (fma x, y (fma u, v, (fneg z))) + if (N0.getOpcode() == FusedOpcode && + N0.getOperand(2).getOpcode() == ISD::FMUL) { + return DAG.getNode(FusedOpcode, SDLoc(N), VT, + N0.getOperand(0), N0.getOperand(1), + DAG.getNode(FusedOpcode, SDLoc(N), VT, + N0.getOperand(2).getOperand(0), + N0.getOperand(2).getOperand(1), + DAG.getNode(ISD::FNEG, SDLoc(N), VT, + N1))); + } + + // fold (fsub x, (fma y, z, (fmul u, v))) + // -> (fma (fneg y), z, (fma (fneg u), v, x)) + if (N1.getOpcode() == FusedOpcode && + N1.getOperand(2).getOpcode() == ISD::FMUL) { + SDValue N20 = N1.getOperand(2).getOperand(0); + SDValue N21 = N1.getOperand(2).getOperand(1); + return DAG.getNode(FusedOpcode, SDLoc(N), VT, + DAG.getNode(ISD::FNEG, SDLoc(N), VT, + N1.getOperand(0)), + N1.getOperand(1), + DAG.getNode(FusedOpcode, SDLoc(N), VT, + DAG.getNode(ISD::FNEG, SDLoc(N), VT, + N20), + N21, N0)); + } + } + + return SDValue(); +} + SDValue DAGCombiner::visitFADD(SDNode *N) { SDValue N0 = N->getOperand(0); SDValue N1 = N->getOperand(1); @@ -7077,23 +7204,27 @@ SDValue DAGCombiner::visitFADD(SDNode *N) { } } // enable-unsafe-fp-math + if (LegalOperations && TLI.isOperationLegal(ISD::FMAD, VT)) { + // Assume if there is an fmad instruction that it should be aggressively + // used. + if (SDValue Fused = performFaddFmulCombines(ISD::FMAD, true, N, TLI, DAG)) + return Fused; + } + // FADD -> FMA combines: if ((Options.AllowFPOpFusion == FPOpFusion::Fast || Options.UnsafeFPMath) && TLI.isFMAFasterThanFMulAndFAdd(VT) && (!LegalOperations || TLI.isOperationLegalOrCustom(ISD::FMA, VT))) { - // fold (fadd (fmul x, y), z) -> (fma x, y, z) - if (N0.getOpcode() == ISD::FMUL && - (N0->hasOneUse() || TLI.enableAggressiveFMAFusion(VT))) - return DAG.getNode(ISD::FMA, SDLoc(N), VT, - N0.getOperand(0), N0.getOperand(1), N1); - - // fold (fadd x, (fmul y, z)) -> (fma y, z, x) - // Note: Commutes FADD operands. - if (N1.getOpcode() == ISD::FMUL && - (N1->hasOneUse() || TLI.enableAggressiveFMAFusion(VT))) - return DAG.getNode(ISD::FMA, SDLoc(N), VT, - N1.getOperand(0), N1.getOperand(1), N0); + if (!TLI.isOperationLegal(ISD::FMAD, VT)) { + // Don't form FMA if we are preferring FMAD. + if (SDValue Fused + = performFaddFmulCombines(ISD::FMA, + TLI.enableAggressiveFMAFusion(VT), + N, TLI, DAG)) { + return Fused; + } + } // When FP_EXTEND nodes are free on the target, and there is an opportunity // to combine into FMA, arrange such nodes accordingly. @@ -7122,30 +7253,6 @@ SDValue DAGCombiner::visitFADD(SDNode *N) { N10.getOperand(1)), N0); } } - - // More folding opportunities when target permits. - if (TLI.enableAggressiveFMAFusion(VT)) { - - // fold (fadd (fma x, y, (fmul u, v)), z) -> (fma x, y (fma u, v, z)) - if (N0.getOpcode() == ISD::FMA && - N0.getOperand(2).getOpcode() == ISD::FMUL) - return DAG.getNode(ISD::FMA, SDLoc(N), VT, - N0.getOperand(0), N0.getOperand(1), - DAG.getNode(ISD::FMA, SDLoc(N), VT, - N0.getOperand(2).getOperand(0), - N0.getOperand(2).getOperand(1), - N1)); - - // fold (fadd x, (fma y, z, (fmul u, v)) -> (fma y, z (fma u, v, x)) - if (N1->getOpcode() == ISD::FMA && - N1.getOperand(2).getOpcode() == ISD::FMUL) - return DAG.getNode(ISD::FMA, SDLoc(N), VT, - N1.getOperand(0), N1.getOperand(1), - DAG.getNode(ISD::FMA, SDLoc(N), VT, - N1.getOperand(2).getOperand(0), - N1.getOperand(2).getOperand(1), - N0)); - } } return SDValue(); @@ -7207,43 +7314,32 @@ SDValue DAGCombiner::visitFSUB(SDNode *N) { } } + if (LegalOperations && TLI.isOperationLegal(ISD::FMAD, VT)) { + // Assume if there is an fmad instruction that it should be aggressively + // used. + if (SDValue Fused = performFsubFmulCombines(ISD::FMAD, true, N, TLI, DAG)) + return Fused; + } + // FSUB -> FMA combines: if ((Options.AllowFPOpFusion == FPOpFusion::Fast || Options.UnsafeFPMath) && TLI.isFMAFasterThanFMulAndFAdd(VT) && (!LegalOperations || TLI.isOperationLegalOrCustom(ISD::FMA, VT))) { - // fold (fsub (fmul x, y), z) -> (fma x, y, (fneg z)) - if (N0.getOpcode() == ISD::FMUL && - (N0->hasOneUse() || TLI.enableAggressiveFMAFusion(VT))) - return DAG.getNode(ISD::FMA, dl, VT, - N0.getOperand(0), N0.getOperand(1), - DAG.getNode(ISD::FNEG, dl, VT, N1)); + if (!TLI.isOperationLegal(ISD::FMAD, VT)) { + // Don't form FMA if we are preferring FMAD. - // fold (fsub x, (fmul y, z)) -> (fma (fneg y), z, x) - // Note: Commutes FSUB operands. - if (N1.getOpcode() == ISD::FMUL && - (N1->hasOneUse() || TLI.enableAggressiveFMAFusion(VT))) - return DAG.getNode(ISD::FMA, dl, VT, - DAG.getNode(ISD::FNEG, dl, VT, - N1.getOperand(0)), - N1.getOperand(1), N0); - - // fold (fsub (fneg (fmul, x, y)), z) -> (fma (fneg x), y, (fneg z)) - if (N0.getOpcode() == ISD::FNEG && - N0.getOperand(0).getOpcode() == ISD::FMUL && - ((N0->hasOneUse() && N0.getOperand(0).hasOneUse()) || - TLI.enableAggressiveFMAFusion(VT))) { - SDValue N00 = N0.getOperand(0).getOperand(0); - SDValue N01 = N0.getOperand(0).getOperand(1); - return DAG.getNode(ISD::FMA, dl, VT, - DAG.getNode(ISD::FNEG, dl, VT, N00), N01, - DAG.getNode(ISD::FNEG, dl, VT, N1)); + if (SDValue Fused + = performFsubFmulCombines(ISD::FMA, + TLI.enableAggressiveFMAFusion(VT), + N, TLI, DAG)) { + return Fused; + } } // When FP_EXTEND nodes are free on the target, and there is an opportunity // to combine into FMA, arrange such nodes accordingly. if (TLI.isFPExtFree(VT)) { - // fold (fsub (fpext (fmul x, y)), z) // -> (fma (fpext x), (fpext y), (fneg z)) if (N0.getOpcode() == ISD::FP_EXTEND) { @@ -7308,38 +7404,6 @@ SDValue DAGCombiner::visitFSUB(SDNode *N) { } } } - - // More folding opportunities when target permits. - if (TLI.enableAggressiveFMAFusion(VT)) { - - // fold (fsub (fma x, y, (fmul u, v)), z) - // -> (fma x, y (fma u, v, (fneg z))) - if (N0.getOpcode() == ISD::FMA && - N0.getOperand(2).getOpcode() == ISD::FMUL) - return DAG.getNode(ISD::FMA, SDLoc(N), VT, - N0.getOperand(0), N0.getOperand(1), - DAG.getNode(ISD::FMA, SDLoc(N), VT, - N0.getOperand(2).getOperand(0), - N0.getOperand(2).getOperand(1), - DAG.getNode(ISD::FNEG, SDLoc(N), VT, - N1))); - - // fold (fsub x, (fma y, z, (fmul u, v))) - // -> (fma (fneg y), z, (fma (fneg u), v, x)) - if (N1.getOpcode() == ISD::FMA && - N1.getOperand(2).getOpcode() == ISD::FMUL) { - SDValue N20 = N1.getOperand(2).getOperand(0); - SDValue N21 = N1.getOperand(2).getOperand(1); - return DAG.getNode(ISD::FMA, SDLoc(N), VT, - DAG.getNode(ISD::FNEG, SDLoc(N), VT, - N1.getOperand(0)), - N1.getOperand(1), - DAG.getNode(ISD::FMA, SDLoc(N), VT, - DAG.getNode(ISD::FNEG, SDLoc(N), VT, - N20), - N21, N0)); - } - } } return SDValue(); diff --git a/lib/CodeGen/SelectionDAG/LegalizeDAG.cpp b/lib/CodeGen/SelectionDAG/LegalizeDAG.cpp index e5473e35cae..ed337eb9648 100644 --- a/lib/CodeGen/SelectionDAG/LegalizeDAG.cpp +++ b/lib/CodeGen/SelectionDAG/LegalizeDAG.cpp @@ -3519,6 +3519,9 @@ void SelectionDAGLegalize::ExpandNode(SDNode *Node) { RTLIB::FMA_F80, RTLIB::FMA_F128, RTLIB::FMA_PPCF128)); break; + case ISD::FMAD: + llvm_unreachable("Illegal fmad should never be formed"); + case ISD::FADD: Results.push_back(ExpandFPLibCall(Node, RTLIB::ADD_F32, RTLIB::ADD_F64, RTLIB::ADD_F80, RTLIB::ADD_F128, diff --git a/lib/CodeGen/SelectionDAG/SelectionDAGDumper.cpp b/lib/CodeGen/SelectionDAG/SelectionDAGDumper.cpp index e8577d898c2..17eff944c62 100644 --- a/lib/CodeGen/SelectionDAG/SelectionDAGDumper.cpp +++ b/lib/CodeGen/SelectionDAG/SelectionDAGDumper.cpp @@ -187,6 +187,7 @@ std::string SDNode::getOperationName(const SelectionDAG *G) const { case ISD::FMUL: return "fmul"; case ISD::FDIV: return "fdiv"; case ISD::FMA: return "fma"; + case ISD::FMAD: return "fmad"; case ISD::FREM: return "frem"; case ISD::FCOPYSIGN: return "fcopysign"; case ISD::FGETSIGN: return "fgetsign"; diff --git a/lib/CodeGen/TargetLoweringBase.cpp b/lib/CodeGen/TargetLoweringBase.cpp index 630c3313772..459969b58b9 100644 --- a/lib/CodeGen/TargetLoweringBase.cpp +++ b/lib/CodeGen/TargetLoweringBase.cpp @@ -765,6 +765,7 @@ void TargetLoweringBase::initActions() { setOperationAction(ISD::CONCAT_VECTORS, VT, Expand); setOperationAction(ISD::FMINNUM, VT, Expand); setOperationAction(ISD::FMAXNUM, VT, Expand); + setOperationAction(ISD::FMAD, VT, Expand); // These library functions default to expand. setOperationAction(ISD::FROUND, VT, Expand);