diff --git a/lib/Transforms/InstCombine/InstCombineCalls.cpp b/lib/Transforms/InstCombine/InstCombineCalls.cpp
index 0e464507a7e..bfdc17eff7e 100644
--- a/lib/Transforms/InstCombine/InstCombineCalls.cpp
+++ b/lib/Transforms/InstCombine/InstCombineCalls.cpp
@@ -475,7 +475,35 @@ Instruction *InstCombiner::visitCallInst(CallInst &CI) {
       }
     }
     break;
-  case Intrinsic::umul_with_overflow:
+  case Intrinsic::umul_with_overflow: {
+    Value *LHS = II->getArgOperand(0), *RHS = II->getArgOperand(1);
+    unsigned BitWidth = cast<IntegerType>(LHS->getType())->getBitWidth();
+    APInt Mask = APInt::getAllOnesValue(BitWidth);
+
+    APInt LHSKnownZero(BitWidth, 0);
+    APInt LHSKnownOne(BitWidth, 0);
+    ComputeMaskedBits(LHS, Mask, LHSKnownZero, LHSKnownOne);
+    APInt RHSKnownZero(BitWidth, 0);
+    APInt RHSKnownOne(BitWidth, 0);
+    ComputeMaskedBits(RHS, Mask, RHSKnownZero, RHSKnownOne);
+
+    // Get the largest possible values for each operand, extended to be large
+    // enough so that every possible product of two BitWidth-sized ints fits.
+    APInt LHSMax = (~LHSKnownZero).zext(BitWidth*2);
+    APInt RHSMax = (~RHSKnownZero).zext(BitWidth*2);
+
+    // If multiplying the maximum values does not overflow then we can turn
+    // this into a plain NUW mul.
+    if ((LHSMax * RHSMax).getActiveBits() <= BitWidth) {
+      Value *Mul = Builder->CreateNUWMul(LHS, RHS, "umul_with_overflow");
+      Constant *V[] = {
+        UndefValue::get(LHS->getType()),
+        Builder->getFalse()
+      };
+      Constant *Struct = ConstantStruct::get(II->getContext(), V, 2, false);
+      return InsertValueInst::Create(Struct, Mul, 0);
+    }
+  } // FALL THROUGH
   case Intrinsic::smul_with_overflow:
     // Canonicalize constants into the RHS.
     if (isa<Constant>(II->getArgOperand(0)) &&
diff --git a/test/Transforms/InstCombine/intrinsics.ll b/test/Transforms/InstCombine/intrinsics.ll
index 50e7f1f7c92..332cd46098c 100644
--- a/test/Transforms/InstCombine/intrinsics.ll
+++ b/test/Transforms/InstCombine/intrinsics.ll
@@ -112,6 +112,33 @@ define i8 @umultest2(i8 %A, i1* %overflowPtr) {
 ; CHECK-NEXT: ret i8 %A
 }
 
+%ov.result.32 = type { i32, i1 }
+declare %ov.result.32 @llvm.umul.with.overflow.i32(i32, i32) nounwind readnone
+
+define i32 @umultest3(i32 %n) nounwind {
+  %shr = lshr i32 %n, 2
+  %mul = call %ov.result.32 @llvm.umul.with.overflow.i32(i32 %shr, i32 3)
+  %ov = extractvalue %ov.result.32 %mul, 1
+  %res = extractvalue %ov.result.32 %mul, 0
+  %ret = select i1 %ov, i32 -1, i32 %res
+  ret i32 %ret
+; CHECK: @umultest3
+; CHECK-NEXT: shr
+; CHECK-NEXT: mul nuw
+; CHECK-NEXT: ret
+}
+
+define i32 @umultest4(i32 %n) nounwind {
+  %shr = lshr i32 %n, 1
+  %mul = call %ov.result.32 @llvm.umul.with.overflow.i32(i32 %shr, i32 4)
+  %ov = extractvalue %ov.result.32 %mul, 1
+  %res = extractvalue %ov.result.32 %mul, 0
+  %ret = select i1 %ov, i32 -1, i32 %res
+  ret i32 %ret
+; CHECK: @umultest4
+; CHECK: umul.with.overflow
+}
+
 define void @powi(double %V, double *%P) {
 entry:
   %A = tail call double @llvm.powi.f64(double %V, i32 -1) nounwind