diff --git a/include/llvm/IR/Operator.h b/include/llvm/IR/Operator.h index 119156a50e0..a30a84685d8 100644 --- a/include/llvm/IR/Operator.h +++ b/include/llvm/IR/Operator.h @@ -210,6 +210,10 @@ public: setNoSignedZeros(); setAllowReciprocal(); } + + void operator&=(const FastMathFlags &OtherFlags) { + Flags &= OtherFlags.Flags; + } }; diff --git a/lib/Transforms/InstCombine/InstCombineCasts.cpp b/lib/Transforms/InstCombine/InstCombineCasts.cpp index 71990a27ac0..29ab6c0623e 100644 --- a/lib/Transforms/InstCombine/InstCombineCasts.cpp +++ b/lib/Transforms/InstCombine/InstCombineCasts.cpp @@ -1232,7 +1232,10 @@ Instruction *InstCombiner::visitFPTrunc(FPTruncInst &CI) { LHSOrig = Builder->CreateFPExt(LHSOrig, CI.getType()); if (RHSOrig->getType() != CI.getType()) RHSOrig = Builder->CreateFPExt(RHSOrig, CI.getType()); - return BinaryOperator::Create(OpI->getOpcode(), LHSOrig, RHSOrig); + Instruction *RI = + BinaryOperator::Create(OpI->getOpcode(), LHSOrig, RHSOrig); + RI->copyFastMathFlags(OpI); + return RI; } break; case Instruction::FMul: @@ -1246,7 +1249,10 @@ Instruction *InstCombiner::visitFPTrunc(FPTruncInst &CI) { LHSOrig = Builder->CreateFPExt(LHSOrig, CI.getType()); if (RHSOrig->getType() != CI.getType()) RHSOrig = Builder->CreateFPExt(RHSOrig, CI.getType()); - return BinaryOperator::CreateFMul(LHSOrig, RHSOrig); + Instruction *RI = + BinaryOperator::CreateFMul(LHSOrig, RHSOrig); + RI->copyFastMathFlags(OpI); + return RI; } break; case Instruction::FDiv: @@ -1261,7 +1267,10 @@ Instruction *InstCombiner::visitFPTrunc(FPTruncInst &CI) { LHSOrig = Builder->CreateFPExt(LHSOrig, CI.getType()); if (RHSOrig->getType() != CI.getType()) RHSOrig = Builder->CreateFPExt(RHSOrig, CI.getType()); - return BinaryOperator::CreateFDiv(LHSOrig, RHSOrig); + Instruction *RI = + BinaryOperator::CreateFDiv(LHSOrig, RHSOrig); + RI->copyFastMathFlags(OpI); + return RI; } break; case Instruction::FRem: @@ -1274,6 +1283,8 @@ Instruction *InstCombiner::visitFPTrunc(FPTruncInst &CI) { else if (RHSWidth <= SrcWidth) RHSOrig = Builder->CreateFPExt(RHSOrig, LHSOrig->getType()); Value *ExactResult = Builder->CreateFRem(LHSOrig, RHSOrig); + if (Instruction *RI = dyn_cast(ExactResult)) + RI->copyFastMathFlags(OpI); return CastInst::CreateFPCast(ExactResult, CI.getType()); } @@ -1281,7 +1292,9 @@ Instruction *InstCombiner::visitFPTrunc(FPTruncInst &CI) { if (BinaryOperator::isFNeg(OpI)) { Value *InnerTrunc = Builder->CreateFPTrunc(OpI->getOperand(1), CI.getType()); - return BinaryOperator::CreateFNeg(InnerTrunc); + Instruction *RI = BinaryOperator::CreateFNeg(InnerTrunc); + RI->copyFastMathFlags(OpI); + return RI; } } diff --git a/lib/Transforms/InstCombine/InstCombineSelect.cpp b/lib/Transforms/InstCombine/InstCombineSelect.cpp index 283bec2881f..555ffc77523 100644 --- a/lib/Transforms/InstCombine/InstCombineSelect.cpp +++ b/lib/Transforms/InstCombine/InstCombineSelect.cpp @@ -901,6 +901,11 @@ Instruction *InstCombiner::visitSelectInst(SelectInst &SI) { Value *NegVal; // Compute -Z if (SI.getType()->isFPOrFPVectorTy()) { NegVal = Builder->CreateFNeg(SubOp->getOperand(1)); + if (Instruction *NegInst = dyn_cast(NegVal)) { + FastMathFlags Flags = AddOp->getFastMathFlags(); + Flags &= SubOp->getFastMathFlags(); + NegInst->setFastMathFlags(Flags); + } } else { NegVal = Builder->CreateNeg(SubOp->getOperand(1)); } @@ -913,9 +918,15 @@ Instruction *InstCombiner::visitSelectInst(SelectInst &SI) { Builder->CreateSelect(CondVal, NewTrueOp, NewFalseOp, SI.getName() + ".p"); - if (SI.getType()->isFPOrFPVectorTy()) - return BinaryOperator::CreateFAdd(SubOp->getOperand(0), NewSel); - else + if (SI.getType()->isFPOrFPVectorTy()) { + Instruction *RI = + BinaryOperator::CreateFAdd(SubOp->getOperand(0), NewSel); + + FastMathFlags Flags = AddOp->getFastMathFlags(); + Flags &= SubOp->getFastMathFlags(); + RI->setFastMathFlags(Flags); + return RI; + } else return BinaryOperator::CreateAdd(SubOp->getOperand(0), NewSel); } } diff --git a/lib/Transforms/InstCombine/InstCombineVectorOps.cpp b/lib/Transforms/InstCombine/InstCombineVectorOps.cpp index c0b9b2fc3e5..178be61b43e 100644 --- a/lib/Transforms/InstCombine/InstCombineVectorOps.cpp +++ b/lib/Transforms/InstCombine/InstCombineVectorOps.cpp @@ -638,6 +638,8 @@ static Value *BuildNew(Instruction *I, ArrayRef NewOps) { if (isa(BO)) { New->setIsExact(BO->isExact()); } + if (isa(BO)) + New->copyFastMathFlags(I); return New; } case Instruction::ICmp: diff --git a/test/Transforms/InstCombine/fast-math.ll b/test/Transforms/InstCombine/fast-math.ll index 0371488dfd8..de51c494685 100644 --- a/test/Transforms/InstCombine/fast-math.ll +++ b/test/Transforms/InstCombine/fast-math.ll @@ -160,6 +160,22 @@ define float @fold15(float %x, float %y) { ; CHECK: ret } +; (select X+Y, X-Y) => X + (select Y, -Y) +define float @fold16(float %x, float %y) { + %cmp = fcmp ogt float %x, %y + %plus = fadd fast float %x, %y + %minus = fsub fast float %x, %y + %r = select i1 %cmp, float %plus, float %minus + ret float %r +; CHECK: fold16 +; CHECK: fsub fast float +; CHECK: select +; CHECK: fadd fast float +; CHECK: ret +} + + + ; ========================================================================= ; ; Testing-cases about fmul begin diff --git a/test/Transforms/InstCombine/fpcast.ll b/test/Transforms/InstCombine/fpcast.ll index 05d1b48d599..9be66fd42c6 100644 --- a/test/Transforms/InstCombine/fpcast.ll +++ b/test/Transforms/InstCombine/fpcast.ll @@ -31,6 +31,15 @@ define half @test4(float %a) { ret half %c } +; CHECK: test4-fast +define half @test4-fast(float %a) { +; CHECK: fptrunc +; CHECK: fsub fast + %b = fsub fast float -0.0, %a + %c = fptrunc float %b to half + ret half %c +} + ; CHECK: test5 define half @test5(float %a, float %b, float %c) { ; CHECK: fcmp ogt