diff --git a/lib/CodeGen/SelectionDAG/DAGCombiner.cpp b/lib/CodeGen/SelectionDAG/DAGCombiner.cpp index cd528c5a32a..ee6c2a39f3e 100644 --- a/lib/CodeGen/SelectionDAG/DAGCombiner.cpp +++ b/lib/CodeGen/SelectionDAG/DAGCombiner.cpp @@ -5988,6 +5988,7 @@ SDValue DAGCombiner::visitFMA(SDNode *N) { ConstantFPSDNode *N0CFP = dyn_cast(N0); ConstantFPSDNode *N1CFP = dyn_cast(N1); EVT VT = N->getValueType(0); + DebugLoc dl = N->getDebugLoc(); if (N0CFP && N0CFP->isExactlyValue(1.0)) return DAG.getNode(ISD::FADD, N->getDebugLoc(), VT, N1, N2); @@ -5998,6 +5999,58 @@ SDValue DAGCombiner::visitFMA(SDNode *N) { if (N0CFP && !N1CFP) return DAG.getNode(ISD::FMA, N->getDebugLoc(), VT, N1, N0, N2); + // (fma x, c1, (fmul x, c2)) -> (fmul x, c1+c2) + if (DAG.getTarget().Options.UnsafeFPMath && N1CFP && + N2.getOpcode() == ISD::FMUL && + N0 == N2.getOperand(0) && + N2.getOperand(1).getOpcode() == ISD::ConstantFP) { + return DAG.getNode(ISD::FMUL, dl, VT, N0, + DAG.getNode(ISD::FADD, dl, VT, N1, N2.getOperand(1))); + } + + + // (fma (fmul x, c1), c2, y) -> (fma x, c1*c2, y) + if (DAG.getTarget().Options.UnsafeFPMath && + N0.getOpcode() == ISD::FMUL && N1CFP && + N0.getOperand(1).getOpcode() == ISD::ConstantFP) { + return DAG.getNode(ISD::FMA, dl, VT, + N0.getOperand(0), + DAG.getNode(ISD::FMUL, dl, VT, N1, N0.getOperand(1)), + N2); + } + + // (fma x, 1, y) -> (fadd x, y) + // (fma x, -1, y) -> (fadd (fneg x), y) + if (N1CFP) { + if (N1CFP->isExactlyValue(1.0)) + return DAG.getNode(ISD::FADD, dl, VT, N0, N2); + + if (N1CFP->isExactlyValue(-1.0) && + (!LegalOperations || TLI.isOperationLegal(ISD::FNEG, VT))) { + SDValue RHSNeg = DAG.getNode(ISD::FNEG, dl, VT, N0); + AddToWorkList(RHSNeg.getNode()); + return DAG.getNode(ISD::FADD, dl, VT, N2, RHSNeg); + } + } + + // (fma x, c, x) -> (fmul x, (c+1)) + if (DAG.getTarget().Options.UnsafeFPMath && N1CFP && N0 == N2) { + return DAG.getNode(ISD::FMUL, dl, VT, + N0, + DAG.getNode(ISD::FADD, dl, VT, + N1, DAG.getConstantFP(1.0, VT))); + } + + // (fma x, c, (fneg x)) -> (fmul x, (c-1)) + if (DAG.getTarget().Options.UnsafeFPMath && N1CFP && + N2.getOpcode() == ISD::FNEG && N2.getOperand(0) == N0) { + return DAG.getNode(ISD::FMUL, dl, VT, + N0, + DAG.getNode(ISD::FADD, dl, VT, + N1, DAG.getConstantFP(-1.0, VT))); + } + + return SDValue(); } @@ -6367,6 +6420,17 @@ SDValue DAGCombiner::visitFNEG(SDNode *N) { } } + // (fneg (fmul c, x)) -> (fmul -c, x) + if (N0.getOpcode() == ISD::FMUL) { + ConstantFPSDNode *CFP1 = dyn_cast(N0.getOperand(1)); + if (CFP1) { + return DAG.getNode(ISD::FMUL, N->getDebugLoc(), VT, + N0.getOperand(0), + DAG.getNode(ISD::FNEG, N->getDebugLoc(), VT, + N0.getOperand(1))); + } + } + return SDValue(); } diff --git a/test/CodeGen/ARM/fp-fast.ll b/test/CodeGen/ARM/fp-fast.ll new file mode 100644 index 00000000000..ec571873817 --- /dev/null +++ b/test/CodeGen/ARM/fp-fast.ll @@ -0,0 +1,60 @@ +; RUN: llc -march=arm -mcpu=cortex-a9 -mattr=+vfp4 -enable-unsafe-fp-math < %s | FileCheck %s + +; CHECK: test1 +define float @test1(float %x) { +; CHECK-NOT: vfma +; CHECK: vmul.f32 +; CHECK-NOT: vfma + %t1 = fmul float %x, 3.0 + %t2 = call float @llvm.fma.f32(float %x, float 2.0, float %t1) + ret float %t2 +} + +; CHECK: test2 +define float @test2(float %x, float %y) { +; CHECK-NOT: vmul +; CHECK: vfma.f32 +; CHECK-NOT: vmul + %t1 = fmul float %x, 3.0 + %t2 = call float @llvm.fma.f32(float %t1, float 2.0, float %y) + ret float %t2 +} + +; CHECK: test3 +define float @test3(float %x, float %y) { +; CHECK-NOT: vfma +; CHECK: vadd.f32 +; CHECK-NOT: vfma + %t2 = call float @llvm.fma.f32(float %x, float 1.0, float %y) + ret float %t2 +} + +; CHECK: test4 +define float @test4(float %x, float %y) { +; CHECK-NOT: vfma +; CHECK: vsub.f32 +; CHECK-NOT: vfma + %t2 = call float @llvm.fma.f32(float %x, float -1.0, float %y) + ret float %t2 +} + +; CHECK: test5 +define float @test5(float %x) { +; CHECK-NOT: vfma +; CHECK: vmul.f32 +; CHECK-NOT: vfma + %t2 = call float @llvm.fma.f32(float %x, float 2.0, float %x) + ret float %t2 +} + +; CHECK: test6 +define float @test6(float %x) { +; CHECK-NOT: vfma +; CHECK: vmul.f32 +; CHECK-NOT: vfma + %t1 = fsub float -0.0, %x + %t2 = call float @llvm.fma.f32(float %x, float 5.0, float %t1) + ret float %t2 +} + +declare float @llvm.fma.f32(float, float, float)