diff --git a/lib/Transforms/InstCombine/InstCombineCompares.cpp b/lib/Transforms/InstCombine/InstCombineCompares.cpp
index 41eea27f003..80497e9391a 100644
--- a/lib/Transforms/InstCombine/InstCombineCompares.cpp
+++ b/lib/Transforms/InstCombine/InstCombineCompares.cpp
@@ -1583,6 +1583,72 @@ Instruction *InstCombiner::visitICmpInstWithCastAndCast(ICmpInst &ICI) {
   return BinaryOperator::CreateNot(Result);
 }
 
+/// ProcessUGT_ADDCST_ADD - The caller has matched a pattern of the form:
+///   I = icmp ugt (add (add A, B), CI2), CI1
+///
+/// If CI1 is the all-ones mask of the half-width type and CI2 has only that
+/// type's sign bit set, rewrite the check to use llvm.sadd.with.overflow on
+/// the half-width type: the inner add is replaced by the zero-extended
+/// intrinsic result, and the returned instruction extracts the overflow bit.
+/// Returns 0 if the constants do not match.
+///
+/// NOTE(review): A and B are truncated to the half-width type without any
+/// check visible here that they fit in it; confirm the pattern guarantees
+/// that before relying on this transform.
+static Instruction *ProcessUGT_ADDCST_ADD(ICmpInst &I, Value *A, Value *B,
+                                          ConstantInt *CI2, ConstantInt *CI1,
+                                          InstCombiner::BuilderTy *Builder) {
+  const IntegerType *WideType = cast<IntegerType>(CI1->getType());
+  unsigned WideWidth = WideType->getBitWidth();
+  unsigned NarrowWidth = WideWidth / 2;
+  const IntegerType *NarrowType =
+    IntegerType::get(CI1->getContext(), NarrowWidth);
+
+  // NarrowAllOnes and NarrowSignBit are the magic constants used to
+  // perform an overflow check in the wider type: 0x00..00FF..FF and
+  // 0x00..0010..00 respectively, where the highest set bit in each is
+  // what would be the sign bit in the narrower type.
+  ConstantInt *NarrowAllOnes = cast<ConstantInt>(ConstantInt::get(WideType,
+                        APInt::getAllOnesValue(NarrowWidth).zext(WideWidth)));
+  APInt SignBit(WideWidth, 0);
+  SignBit.setBit(NarrowWidth-1);
+  ConstantInt *NarrowSignBit =
+    cast<ConstantInt>(ConstantInt::get(WideType, SignBit));
+
+  if (CI1 != NarrowAllOnes || CI2 != NarrowSignBit)
+    return 0;
+
+  Module *M = I.getParent()->getParent()->getParent();
+
+  const Type *IntrinsicType = NarrowType;
+  Value *F = Intrinsic::getDeclaration(M, Intrinsic::sadd_with_overflow,
+                                       &IntrinsicType, 1);
+
+  BasicBlock *InitialBlock = Builder->GetInsertBlock();
+  BasicBlock::iterator InitialInsert = Builder->GetInsertPoint();
+
+  // If the pattern matches, truncate the inputs to the narrower type and
+  // use the sadd_with_overflow intrinsic to efficiently compute both the
+  // result and the overflow bit.
+  Instruction *OrigAdd =
+    cast<Instruction>(cast<Instruction>(I.getOperand(0))->getOperand(0));
+  Builder->SetInsertPoint(OrigAdd->getParent(),
+                          BasicBlock::iterator(OrigAdd));
+
+  Value *TruncA = Builder->CreateTrunc(A, NarrowType, A->getName());
+  Value *TruncB = Builder->CreateTrunc(B, NarrowType, B->getName());
+  CallInst *Call = Builder->CreateCall2(F, TruncA, TruncB);
+  Value *Add = Builder->CreateExtractValue(Call, 0);
+  Value *ZExt = Builder->CreateZExt(Add, WideType);
+
+  // The inner add was the result of the narrow add, zero extended to the
+  // wider type. Replace it with the result computed by the intrinsic.
+  OrigAdd->replaceAllUsesWith(ZExt);
+
+  Builder->SetInsertPoint(InitialBlock, InitialInsert);
+
+  return ExtractValueInst::Create(Call, 1);
+}
 
 
 Instruction *InstCombiner::visitICmpInst(ICmpInst &I) {
@@ -1662,72 +1728,19 @@ Instruction *InstCombiner::visitICmpInst(ICmpInst &I) {
     // addition in wider type, and explicitly checks for overflow using
     // comparisons against INT_MIN and INT_MAX. Simplify this by using the
     // sadd_with_overflow intrinsic.
-    // FIXME: This could probably be generalized to handle other overflow-safe
+    //
+    // TODO: This could probably be generalized to handle other overflow-safe
     // operations if we worked out the formulas to compute the appropriate
     // magic constants.
     //
-    // INT64 : a, b, sum = a + b
-    // if sum < INT32_MIN || sum > INT_MAX then
-    //   ...
-    // else
-    //   ...
+    // sum = a + b
+    // if (sum+128 >u 255) ... -> llvm.sadd.with.overflow.i8
     {
-    ConstantInt *CI2;
-
-    // I = icmp ugt (add (add A B) CI2) CI
+    ConstantInt *CI2; // I = icmp ugt (add (add A, B), CI2), CI
     if (I.getPredicate() == ICmpInst::ICMP_UGT &&
-        match(Op0, m_Add(m_Add(m_Value(A), m_Value(B)),
-                         m_ConstantInt(CI2)))) {
-      const IntegerType *WideType = cast<IntegerType>(CI->getType());
-      unsigned WideWidth = WideType->getBitWidth();
-      unsigned NarrowWidth = WideWidth / 2;
-      const IntegerType *NarrowType =
-        IntegerType::get(CI->getContext(), NarrowWidth);
-
-      // NarrowAllOnes and NarrowSignBit are the magic constants used to
-      // perform an overflow check in the wider type: 0x00..00FF..FF and
-      // 0x00..0010..00 respectively, where the highest set bit in each is
-      // what would be the sign bit in the narrower type.
-      ConstantInt *NarrowAllOnes = cast<ConstantInt>(ConstantInt::get(WideType,
-                        APInt::getAllOnesValue(NarrowWidth).zext(WideWidth)));
-      APInt SignBit(WideWidth, 0);
-      SignBit.setBit(NarrowWidth-1);
-      ConstantInt *NarrowSignBit =
-        cast<ConstantInt>(ConstantInt::get(WideType, SignBit));
-
-      if (CI == NarrowAllOnes && CI2 == NarrowSignBit) {
-        Module *M = I.getParent()->getParent()->getParent();
-
-        const Type *IntrinsicType = NarrowType;
-        Value *F = Intrinsic::getDeclaration(M, Intrinsic::sadd_with_overflow,
-                                             &IntrinsicType, 1);
-
-        BasicBlock *InitialBlock = Builder->GetInsertBlock();
-        BasicBlock::iterator InitialInsert = Builder->GetInsertPoint();
-
-        // If the pattern matches, truncate the inputs to the narrower type and
-        // use the sadd_with_overflow intrinsic to efficiently compute both the
-        // result and the overflow bit.
-        Instruction *OrigAdd =
-          cast<Instruction>(cast<Instruction>(I.getOperand(0))->getOperand(0));
-        Builder->SetInsertPoint(OrigAdd->getParent(),
-                                BasicBlock::iterator(OrigAdd));
-
-        Value *TruncA = Builder->CreateTrunc(A, NarrowType, A->getName());
-        Value *TruncB = Builder->CreateTrunc(B, NarrowType, B->getName());
-        CallInst *Call = Builder->CreateCall2(F, TruncA, TruncB);
-        Value *Add = Builder->CreateExtractValue(Call, 0);
-        Value *ZExt = Builder->CreateZExt(Add, WideType);
-
-        // The inner add was the result of the narrow add, zero extended to the
-        // wider type. Replace it with the result computed by the intrinsic.
-        OrigAdd->replaceAllUsesWith(ZExt);
-
-        Builder->SetInsertPoint(InitialBlock, InitialInsert);
-
-        return ExtractValueInst::Create(Call, 1);
-      }
-    }
+        match(Op0, m_Add(m_Add(m_Value(A), m_Value(B)), m_ConstantInt(CI2))))
+      if (Instruction *Res = ProcessUGT_ADDCST_ADD(I, A, B, CI2, CI, Builder))
+        return Res;
     }
 
     // (icmp ne/eq (sub A B) 0) -> (icmp ne/eq A, B)