diff --git a/lib/Transforms/Scalar/InstructionCombining.cpp b/lib/Transforms/Scalar/InstructionCombining.cpp index 31f82ad1964..36a2ad5e455 100644 --- a/lib/Transforms/Scalar/InstructionCombining.cpp +++ b/lib/Transforms/Scalar/InstructionCombining.cpp @@ -394,8 +394,7 @@ namespace { Value *EvaluateInDifferentType(Value *V, const Type *Ty, bool isSigned); bool CanEvaluateInDifferentType(Value *V, const IntegerType *Ty, - unsigned CastOpc, - int &NumCastsRemoved, bool &SeenTrunc); + unsigned CastOpc, int &NumCastsRemoved); unsigned GetOrEnforceKnownAlignment(Value *V, unsigned PrefAlign = 0); @@ -7497,10 +7496,9 @@ Instruction *InstCombiner::PromoteCastOfAllocation(BitCastInst &CI, /// If CastOpc is a sext or zext, we are asking if the low bits of the value can /// bit computed in a larger type, which is then and'd or sext_in_reg'd to get /// the final result. -bool -InstCombiner::CanEvaluateInDifferentType(Value *V, const IntegerType *Ty, - unsigned CastOpc, - int &NumCastsRemoved, bool &SeenTrunc){ +bool InstCombiner::CanEvaluateInDifferentType(Value *V, const IntegerType *Ty, + unsigned CastOpc, + int &NumCastsRemoved){ // We can always evaluate constants in another type. if (isa(V)) return true; @@ -7520,8 +7518,6 @@ InstCombiner::CanEvaluateInDifferentType(Value *V, const IntegerType *Ty, // casts first. if (!isa(I->getOperand(0)) && I->hasOneUse()) ++NumCastsRemoved; - if (isa(I)) - SeenTrunc = true; return true; } } @@ -7540,9 +7536,9 @@ InstCombiner::CanEvaluateInDifferentType(Value *V, const IntegerType *Ty, case Instruction::Xor: // These operators can all arbitrarily be extended or truncated. return CanEvaluateInDifferentType(I->getOperand(0), Ty, CastOpc, - NumCastsRemoved, SeenTrunc) && + NumCastsRemoved) && CanEvaluateInDifferentType(I->getOperand(1), Ty, CastOpc, - NumCastsRemoved, SeenTrunc); + NumCastsRemoved); case Instruction::Shl: // If we are truncating the result of this SHL, and if it's a shift of a @@ -7552,7 +7548,7 @@ InstCombiner::CanEvaluateInDifferentType(Value *V, const IntegerType *Ty, if (BitWidth < OrigTy->getBitWidth() && CI->getLimitedValue(BitWidth) < BitWidth) return CanEvaluateInDifferentType(I->getOperand(0), Ty, CastOpc, - NumCastsRemoved, SeenTrunc); + NumCastsRemoved); } break; case Instruction::LShr: @@ -7567,7 +7563,7 @@ InstCombiner::CanEvaluateInDifferentType(Value *V, const IntegerType *Ty, APInt::getHighBitsSet(OrigBitWidth, OrigBitWidth-BitWidth)) && CI->getLimitedValue(BitWidth) < BitWidth) { return CanEvaluateInDifferentType(I->getOperand(0), Ty, CastOpc, - NumCastsRemoved, SeenTrunc); + NumCastsRemoved); } } break; @@ -7587,16 +7583,16 @@ InstCombiner::CanEvaluateInDifferentType(Value *V, const IntegerType *Ty, case Instruction::Select: { SelectInst *SI = cast(I); return CanEvaluateInDifferentType(SI->getTrueValue(), Ty, CastOpc, - NumCastsRemoved, SeenTrunc) && + NumCastsRemoved) && CanEvaluateInDifferentType(SI->getFalseValue(), Ty, CastOpc, - NumCastsRemoved, SeenTrunc); + NumCastsRemoved); } case Instruction::PHI: { // We can change a phi if we can change all operands. PHINode *PN = cast(I); for (unsigned i = 0, e = PN->getNumIncomingValues(); i != e; ++i) if (!CanEvaluateInDifferentType(PN->getIncomingValue(i), Ty, CastOpc, - NumCastsRemoved, SeenTrunc)) + NumCastsRemoved)) return false; return true; } @@ -7845,10 +7841,9 @@ Instruction *InstCombiner::commonIntCastTransforms(CastInst &CI) { // Attempt to propagate the cast into the instruction for int->int casts. int NumCastsRemoved = 0; - bool SeenTrunc = false; if (!isa(CI) && CanEvaluateInDifferentType(SrcI, cast(DestTy), - CI.getOpcode(), NumCastsRemoved, SeenTrunc)) { + CI.getOpcode(), NumCastsRemoved)) { // If this cast is a truncate, evaluting in a different type always // eliminates the cast, so it is always a win. If this is a zero-extension, // we need to do an AND to maintain the clear top-part of the computation, @@ -7865,14 +7860,27 @@ Instruction *InstCombiner::commonIntCastTransforms(CastInst &CI) { case Instruction::Trunc: DoXForm = true; break; - case Instruction::ZExt: + case Instruction::ZExt: { DoXForm = NumCastsRemoved >= 1; - // TODO: Check if we need to insert an AND. + if (!DoXForm) { + // If it's unnecessary to issue an AND to clear the high bits, it's + // always profitable to do this xform. + Value *TryRes = EvaluateInDifferentType(SrcI, DestTy, + CI.getOpcode() == Instruction::SExt); + APInt Mask(APInt::getBitsSet(DestBitSize, SrcBitSize, DestBitSize)); + if (MaskedValueIsZero(TryRes, Mask)) + return ReplaceInstUsesWith(CI, TryRes); + else if (Instruction *TryI = dyn_cast(TryRes)) + if (TryI->use_empty()) + EraseInstFromFunction(*TryI); + } break; + } case Instruction::SExt: { DoXForm = NumCastsRemoved >= 2; - if (!SeenTrunc) { - // Do we have to emit a truncate to SrcBitSize followed by a sext? + if (!DoXForm && !isa(SrcI)) { + // If we do not have to emit the truncate + sext pair, then it's always + // profitable to do this xform. // // It's not safe to eliminate the trunc + sext pair if one of the // eliminated cast is a truncate. e.g. @@ -7880,11 +7888,14 @@ Instruction *InstCombiner::commonIntCastTransforms(CastInst &CI) { // t3 = sext i16 t2 to i32 // != // i32 t1 - unsigned NumSignBits = ComputeNumSignBits(&CI); - if (NumSignBits > (DestBitSize - SrcBitSize)) { - DoXForm = true; - JustReplace = true; - } + Value *TryRes = EvaluateInDifferentType(SrcI, DestTy, + CI.getOpcode() == Instruction::SExt); + unsigned NumSignBits = ComputeNumSignBits(TryRes); + if (NumSignBits > (DestBitSize - SrcBitSize)) + return ReplaceInstUsesWith(CI, TryRes); + else if (Instruction *TryI = dyn_cast(TryRes)) + if (TryI->use_empty()) + EraseInstFromFunction(*TryI); } break; } @@ -7893,6 +7904,10 @@ Instruction *InstCombiner::commonIntCastTransforms(CastInst &CI) { if (DoXForm) { Value *Res = EvaluateInDifferentType(SrcI, DestTy, CI.getOpcode() == Instruction::SExt); + if (JustReplace) + // Just replace this cast with the result. + return ReplaceInstUsesWith(CI, Res); + assert(Res->getType() == DestTy); switch (CI.getOpcode()) { default: assert(0 && "Unknown cast type!"); @@ -7901,15 +7916,24 @@ Instruction *InstCombiner::commonIntCastTransforms(CastInst &CI) { // Just replace this cast with the result. return ReplaceInstUsesWith(CI, Res); case Instruction::ZExt: { - // We need to emit an AND to clear the high bits. assert(SrcBitSize < DestBitSize && "Not a zext?"); + + // If the high bits are already zero, just replace this cast with the + // result. + APInt Mask(APInt::getBitsSet(DestBitSize, SrcBitSize, DestBitSize)); + if (MaskedValueIsZero(Res, Mask)) + return ReplaceInstUsesWith(CI, Res); + + // We need to emit an AND to clear the high bits. Constant *C = ConstantInt::get(APInt::getLowBitsSet(DestBitSize, SrcBitSize)); return BinaryOperator::CreateAnd(Res, C); } - case Instruction::SExt: - if (JustReplace) - // Just replace this cast with the result. + case Instruction::SExt: { + // If the high bits are already filled with sign bit, just replace this + // cast with the result. + unsigned NumSignBits = ComputeNumSignBits(Res); + if (NumSignBits > (DestBitSize - SrcBitSize)) return ReplaceInstUsesWith(CI, Res); // We need to emit a cast to truncate, then a cast to sext. @@ -7917,6 +7941,7 @@ Instruction *InstCombiner::commonIntCastTransforms(CastInst &CI) { InsertCastBefore(Instruction::Trunc, Res, Src->getType(), CI), DestTy); } + } } } diff --git a/test/Transforms/InstCombine/cast.ll b/test/Transforms/InstCombine/cast.ll index 9361ff24e9b..7a1e7a802dd 100644 --- a/test/Transforms/InstCombine/cast.ll +++ b/test/Transforms/InstCombine/cast.ll @@ -254,3 +254,10 @@ define i1 @test37(i32 %a) { ret i1 %e } +define i64 @test38(i32 %a) { + %1 = icmp eq i32 %a, -2 + %2 = zext i1 %1 to i8 + %3 = xor i8 %2, 1 + %4 = zext i8 %3 to i64 + ret i64 %4 +}