diff --git a/lib/Transforms/InstCombine/InstCombineCasts.cpp b/lib/Transforms/InstCombine/InstCombineCasts.cpp index 336c32940be..acd78d6bf9f 100644 --- a/lib/Transforms/InstCombine/InstCombineCasts.cpp +++ b/lib/Transforms/InstCombine/InstCombineCasts.cpp @@ -153,7 +153,7 @@ Instruction *InstCombiner::PromoteCastOfAllocation(BitCastInst &CI, /// whether promoting or shrinking integer operations to wider or smaller types /// will allow us to eliminate a truncate or extend. /// -/// This is a truncation operation if Ty is smaller than V->getType(), or an +/// This is a truncation operation if Ty is smaller than V->getType(), or a zero /// extension operation if Ty is larger. /// /// If CastOpc is a truncation, then Ty will be a type smaller than V. We @@ -162,11 +162,13 @@ Instruction *InstCombiner::PromoteCastOfAllocation(BitCastInst &CI, /// inst(trunc(x),trunc(y)), which only makes sense if x and y can be /// efficiently truncated. /// -/// If CastOpc is a sext or zext, we are asking if the low bits of the value can -/// bit computed in a larger type, which is then and'd or sext_in_reg'd to get -/// the final result. +/// If CastOpc is zext, we are asking if the low bits of the value can bit +/// computed in a larger type, which is then and'd to get the final result. static bool CanEvaluateInDifferentType(Value *V, const Type *Ty, - unsigned CastOpc, int &NumCastsRemoved) { + unsigned CastOpc, + unsigned &NumCastsRemoved) { + assert(CastOpc == Instruction::ZExt || CastOpc == Instruction::Trunc); + // We can always evaluate constants in another type. if (isa(V)) return true; @@ -291,9 +293,124 @@ static bool CanEvaluateInDifferentType(Value *V, const Type *Ty, return false; } +/// CanEvaluateSExtd - Return true if we can take the specified value +/// and return it as type Ty without inserting any new casts and without +/// changing the computed value of the common low bits. This is used by code +/// that tries to promote integer operations to a wider types will allow us to +/// eliminate the extension. +/// +/// This returns 0 if we can't do this or the number of sign bits that would be +/// set if we can. For example, CanEvaluateSExtd(i16 1, i64) would return 63, +/// because the computation can be extended (to "i64 1") and the resulting +/// computation has 63 equal sign bits. +/// +/// This function works on both vectors and scalars. For vectors, the result is +/// the number of bits known sign extended in each element. +/// +static unsigned CanEvaluateSExtd(Value *V, const Type *Ty, + unsigned &NumCastsRemoved, TargetData *TD) { + assert(V->getType()->getScalarSizeInBits() < Ty->getScalarSizeInBits() && + "Can't sign extend type to a smaller type"); + // If this is a constant, return the number of sign bits the extended version + // of it would have. + if (Constant *C = dyn_cast(V)) + return ComputeNumSignBits(ConstantExpr::getSExt(C, Ty), TD); + + Instruction *I = dyn_cast(V); + if (!I) return 0; + + // If this is a truncate from the destination type, we can trivially eliminate + // it, and this will remove a cast overall. + if (isa(I) && I->getOperand(0)->getType() == Ty) { + // If the operand of the truncate is itself a cast, and is eliminable, do + // not count this as an eliminable cast. We would prefer to eliminate those + // two casts first. + if (!isa(I->getOperand(0)) && I->hasOneUse()) + ++NumCastsRemoved; + return ComputeNumSignBits(I->getOperand(0), TD); + } + + // We can't extend or shrink something that has multiple uses: doing so would + // require duplicating the instruction in general, which isn't profitable. + if (!I->hasOneUse()) return 0; + + const Type *OrigTy = V->getType(); + + unsigned Opc = I->getOpcode(); + unsigned Tmp1, Tmp2; + switch (Opc) { + case Instruction::And: + case Instruction::Or: + case Instruction::Xor: + // These operators can all arbitrarily be extended or truncated. + Tmp1 = CanEvaluateSExtd(I->getOperand(0), Ty, NumCastsRemoved, TD); + if (Tmp1 == 0) return 0; + Tmp2 = CanEvaluateSExtd(I->getOperand(1), Ty, NumCastsRemoved, TD); + return std::min(Tmp1, Tmp2); + case Instruction::Add: + case Instruction::Sub: + // Add/Sub can have at most one carry/borrow bit. + Tmp1 = CanEvaluateSExtd(I->getOperand(0), Ty, NumCastsRemoved, TD); + if (Tmp1 == 0) return 0; + Tmp2 = CanEvaluateSExtd(I->getOperand(1), Ty, NumCastsRemoved, TD); + if (Tmp2 == 0) return 0; + return std::min(Tmp1, Tmp2)-1; + case Instruction::Mul: + // These operators can all arbitrarily be extended or truncated. + if (!CanEvaluateSExtd(I->getOperand(0), Ty, NumCastsRemoved, TD)) + return 0; + if (!CanEvaluateSExtd(I->getOperand(1), Ty, NumCastsRemoved, TD)) + return 0; + return 1; // IMPROVE? + + //case Instruction::Shl: TODO + //case Instruction::LShr: TODO + //case Instruction::Trunc: TODO + + case Instruction::SExt: + case Instruction::ZExt: { + // sext(sext(x)) -> sext(x) + // sext(zext(x)) -> zext(x) + // Note that replacing a cast does not reduce the number of casts in the + // input. + unsigned InSignBits = ComputeNumSignBits(I, TD); + unsigned ExtBits = Ty->getScalarSizeInBits()-OrigTy->getScalarSizeInBits(); + // We'll end up extending it all the way out. + return InSignBits+ExtBits; + } + case Instruction::Select: { + SelectInst *SI = cast(I); + Tmp1 = CanEvaluateSExtd(SI->getTrueValue(), Ty, NumCastsRemoved, TD); + if (Tmp1 == 0) return 0; + Tmp2 = CanEvaluateSExtd(SI->getFalseValue(), Ty, NumCastsRemoved,TD); + return std::min(Tmp1, Tmp2); + } + case Instruction::PHI: { + // We can change a phi if we can change all operands. Note that we never + // get into trouble with cyclic PHIs here because we only consider + // instructions with a single use. + PHINode *PN = cast(I); + unsigned Result = ~0U; + for (unsigned i = 0, e = PN->getNumIncomingValues(); i != e; ++i) { + Result = std::min(Result, + CanEvaluateSExtd(PN->getIncomingValue(i), Ty, + NumCastsRemoved, TD)); + if (Result == 0) return 0; + } + return Result; + } + default: + // TODO: Can handle more cases here. + break; + } + + return 0; +} + + /// EvaluateInDifferentType - Given an expression that -/// CanEvaluateInDifferentType returns true for, actually insert the code to -/// evaluate the expression. +/// CanEvaluateInDifferentType or CanEvaluateSExtd returns true for, actually +/// insert the code to evaluate the expression. Value *InstCombiner::EvaluateInDifferentType(Value *V, const Type *Ty, bool isSigned) { if (Constant *C = dyn_cast(V)) @@ -469,35 +586,68 @@ Instruction *InstCombiner::commonIntCastTransforms(CastInst &CI) { return 0; // Attempt to propagate the cast into the instruction for int->int casts. - int NumCastsRemoved = 0; - if (!CanEvaluateInDifferentType(Src, DestTy, CI.getOpcode(), NumCastsRemoved)) - return 0; - + unsigned NumCastsRemoved = 0; switch (CI.getOpcode()) { default: assert(0 && "not an integer cast"); case Instruction::Trunc: + if (!CanEvaluateInDifferentType(Src, DestTy, + Instruction::Trunc, NumCastsRemoved)) + return 0; + // If this cast is a truncate, evaluting in a different type always // eliminates the cast, so it is always a win. break; case Instruction::ZExt: + if (!CanEvaluateInDifferentType(Src, DestTy, + Instruction::ZExt, NumCastsRemoved)) + return 0; + // If this is a zero-extension, we need to do an AND to maintain the clear // top-part of the computation, so we require that the input have eliminated // at least one cast. if (NumCastsRemoved < 1) return 0; break; - case Instruction::SExt: - // If this is a sign extension, we insert two new shifts (to do the - // extension) so we require that two casts have been eliminated. - if (NumCastsRemoved < 2) + case Instruction::SExt: { + // Check to see if we can do this transformation, and if so, how many bits + // of the promoted expression will be known copies of the sign bit in the + // result. + unsigned NumBitsSExt = CanEvaluateSExtd(Src, DestTy, NumCastsRemoved, TD); + if (NumBitsSExt == 0) return 0; - break; + + uint32_t SrcBitSize = SrcTy->getScalarSizeInBits(); + uint32_t DestBitSize = DestTy->getScalarSizeInBits(); + + // Because this is a sign extension, we can always transform it by inserting + // two new shifts (to do the extension). However, this is only profitable + // if we've eliminated two or more casts from the input. If we know the + // result will be sign-extendy enough to not require these shifts, we can + // always do the transformation. + if (NumCastsRemoved < 2 && + NumBitsSExt <= DestBitSize-SrcBitSize) + return 0; + + // Okay, we can transform this! Insert the new expression now. + DEBUG(errs() << "ICE: EvaluateInDifferentType converting expression type" + " to avoid sign extend: " << CI); + Value *Res = EvaluateInDifferentType(Src, DestTy, true); + assert(Res->getType() == DestTy); + + // If the high bits are already filled with sign bit, just replace this + // cast with the result. + if (NumBitsSExt > DestBitSize - SrcBitSize || + ComputeNumSignBits(Res) > DestBitSize - SrcBitSize) + return ReplaceInstUsesWith(CI, Res); + + // We need to emit a cast to truncate, then a cast to sext. + return new SExtInst(Builder->CreateTrunc(Res, Src->getType()), DestTy); + } } DEBUG(errs() << "ICE: EvaluateInDifferentType converting expression type" " to avoid cast: " << CI); - Value *Res = EvaluateInDifferentType(Src, DestTy, - CI.getOpcode() == Instruction::SExt); + Value *Res = EvaluateInDifferentType(Src, DestTy, false); assert(Res->getType() == DestTy); uint32_t SrcBitSize = SrcTy->getScalarSizeInBits(); diff --git a/test/Transforms/InstCombine/cast-sext-zext.ll b/test/Transforms/InstCombine/cast-sext-zext.ll deleted file mode 100644 index 678874a2794..00000000000 --- a/test/Transforms/InstCombine/cast-sext-zext.ll +++ /dev/null @@ -1,21 +0,0 @@ -; RUN: opt < %s -instcombine -S | not grep sext -; XFAIL: * -; rdar://6598839 - -define zeroext i16 @t(i8 zeroext %on_off, i16* nocapture %puls) nounwind readonly { -entry: - %0 = zext i8 %on_off to i32 - %1 = add i32 %0, -1 - %2 = sext i32 %1 to i64 - %3 = getelementptr i16* %puls, i64 %2 - %4 = load i16* %3, align 2 - ret i16 %4 -} - -define zeroext i64 @t2(i8 zeroext %on_off) nounwind readonly { -entry: - %0 = zext i8 %on_off to i32 - %1 = add i32 %0, -1 - %2 = sext i32 %1 to i64 - ret i64 %2 ;; Should be (add (zext i8 -> i64), -1) -} diff --git a/test/Transforms/InstCombine/cast.ll b/test/Transforms/InstCombine/cast.ll index a6c6795e844..10e5050125d 100644 --- a/test/Transforms/InstCombine/cast.ll +++ b/test/Transforms/InstCombine/cast.ll @@ -381,3 +381,14 @@ define i32 @test42(i32 %X) { ; CHECK: %Z = and i32 %X, 255 } +; rdar://6598839 +define zeroext i64 @test43(i8 zeroext %on_off) nounwind readonly { + %A = zext i8 %on_off to i32 + %B = add i32 %A, -1 + %C = sext i32 %B to i64 + ret i64 %C ;; Should be (add (zext i8 -> i64), -1) +; CHECK: @test43 +; CHECK-NEXT: %A = zext i8 %on_off to i64 +; CHECK-NEXT: %B = add i64 %A, -1 +; CHECK-NEXT: ret i64 %B +}