diff --git a/lib/VMCore/ConstantFold.cpp b/lib/VMCore/ConstantFold.cpp index c1fcc5f4ed2..2c0a67f1d04 100644 --- a/lib/VMCore/ConstantFold.cpp +++ b/lib/VMCore/ConstantFold.cpp @@ -179,6 +179,151 @@ static Constant *FoldBitCast(LLVMContext &Context, } +/// ExtractConstantBytes - V is an integer constant which only has a subset of +/// its bytes used. The bytes used are indicated by ByteStart (which is the +/// first byte used, counting from the least significant byte) and ByteSize, +/// which is the number of bytes used. +/// +/// This function analyzes the specified constant to see if the specified byte +/// range can be returned as a simplified constant. If so, the constant is +/// returned, otherwise null is returned. +/// +static Constant *ExtractConstantBytes(Constant *C, unsigned ByteStart, + unsigned ByteSize) { + assert(isa(C->getType()) && + (cast(C->getType())->getBitWidth() & 7) == 0 && + "Non-byte sized integer input"); + unsigned CSize = cast(C->getType())->getBitWidth()/8; + assert(ByteSize && "Must be accessing some piece"); + assert(ByteStart+ByteSize <= CSize && "Extracting invalid piece from input"); + assert(ByteSize != CSize && "Should not extract everything"); + + // Constant Integers are simple. + if (ConstantInt *CI = dyn_cast(C)) { + APInt V = CI->getValue(); + if (ByteStart) + V = V.lshr(ByteStart*8); + V.trunc(ByteSize*8); + return ConstantInt::get(CI->getContext(), V); + } + + // In the input is a constant expr, we might be able to recursively simplify. + // If not, we definitely can't do anything. + ConstantExpr *CE = dyn_cast(C); + if (CE == 0) return 0; + + switch (CE->getOpcode()) { + default: return 0; + case Instruction::Or: { + Constant *RHS = ExtractConstantBytes(C->getOperand(1), ByteStart, ByteSize); + if (RHS == 0) + return 0; + + // X | -1 -> -1. + if (ConstantInt *RHSC = dyn_cast(RHS)) + if (RHSC->isAllOnesValue()) + return RHSC; + + Constant *LHS = ExtractConstantBytes(C->getOperand(0), ByteStart, ByteSize); + if (LHS == 0) + return 0; + return ConstantExpr::getOr(LHS, RHS); + } + case Instruction::And: { + Constant *RHS = ExtractConstantBytes(C->getOperand(1), ByteStart, ByteSize); + if (RHS == 0) + return 0; + + // X & 0 -> 0. + if (RHS->isNullValue()) + return RHS; + + Constant *LHS = ExtractConstantBytes(C->getOperand(0), ByteStart, ByteSize); + if (LHS == 0) + return 0; + return ConstantExpr::getAnd(LHS, RHS); + } + case Instruction::LShr: { + ConstantInt *Amt = dyn_cast(CE->getOperand(1)); + if (Amt == 0) + return 0; + unsigned ShAmt = Amt->getZExtValue(); + // Cannot analyze non-byte shifts. + if ((ShAmt & 7) != 0) + return 0; + ShAmt >>= 3; + + // If the extract is known to be all zeros, return zero. + if (ByteStart >= CSize-ShAmt) + return Constant::getNullValue(IntegerType::get(CE->getContext(), + ByteSize*8)); + // If the extract is known to be fully in the input, extract it. + if (ByteStart+ByteSize+ShAmt <= CSize) + return ExtractConstantBytes(C->getOperand(0), ByteStart+ShAmt, ByteSize); + + // TODO: Handle the 'partially zero' case. + return 0; + } + + case Instruction::Shl: { + ConstantInt *Amt = dyn_cast(CE->getOperand(1)); + if (Amt == 0) + return 0; + unsigned ShAmt = Amt->getZExtValue(); + // Cannot analyze non-byte shifts. + if ((ShAmt & 7) != 0) + return 0; + ShAmt >>= 3; + + // If the extract is known to be all zeros, return zero. + if (ByteStart+ByteSize <= ShAmt) + return Constant::getNullValue(IntegerType::get(CE->getContext(), + ByteSize*8)); + // If the extract is known to be fully in the input, extract it. + if (ByteStart >= ShAmt) + return ExtractConstantBytes(C->getOperand(0), ByteStart-ShAmt, ByteSize); + + // TODO: Handle the 'partially zero' case. + return 0; + } + + case Instruction::ZExt: { + unsigned SrcBitSize = + cast(C->getOperand(0)->getType())->getBitWidth(); + + // If extracting something that is completely zero, return 0. + if (ByteStart*8 >= SrcBitSize) + return Constant::getNullValue(IntegerType::get(CE->getContext(), + ByteSize*8)); + + // If exactly extracting the input, return it. + if (ByteStart == 0 && ByteSize*8 == SrcBitSize) + return C->getOperand(0); + + // If extracting something completely in the input, if if the input is a + // multiple of 8 bits, recurse. + if ((SrcBitSize&7) == 0 && (ByteStart+ByteSize)*8 <= SrcBitSize) + return ExtractConstantBytes(C->getOperand(0), ByteStart, ByteSize); + + // Otherwise, if extracting a subset of the input, which is not multiple of + // 8 bits, do a shift and trunc to get the bits. + if ((ByteStart+ByteSize)*8 < SrcBitSize) { + assert((SrcBitSize&7) && "Shouldn't get byte sized case here"); + Constant *Res = C->getOperand(0); + if (ByteStart) + Res = ConstantExpr::getLShr(Res, + ConstantInt::get(Res->getType(), ByteStart*8)); + return ConstantExpr::getTrunc(Res, IntegerType::get(C->getContext(), + ByteSize*8)); + } + + // TODO: Handle the 'partially zero' case. + return 0; + } + } +} + + Constant *llvm::ConstantFoldCastInstruction(LLVMContext &Context, unsigned opc, Constant *V, const Type *DestTy) { @@ -236,6 +381,8 @@ Constant *llvm::ConstantFoldCastInstruction(LLVMContext &Context, // We actually have to do a cast now. Perform the cast according to the // opcode specified. switch (opc) { + default: + llvm_unreachable("Failed to cast constant expression"); case Instruction::FPTrunc: case Instruction::FPExt: if (ConstantFP *FPC = dyn_cast(V)) { @@ -300,23 +447,27 @@ Constant *llvm::ConstantFoldCastInstruction(LLVMContext &Context, return ConstantInt::get(Context, Result); } return 0; - case Instruction::Trunc: + case Instruction::Trunc: { + uint32_t DestBitWidth = cast(DestTy)->getBitWidth(); if (ConstantInt *CI = dyn_cast(V)) { - uint32_t BitWidth = cast(DestTy)->getBitWidth(); APInt Result(CI->getValue()); - Result.trunc(BitWidth); + Result.trunc(DestBitWidth); return ConstantInt::get(Context, Result); } + + // The input must be a constantexpr. See if we can simplify this based on + // the bytes we are demanding. Only do this if the source and dest are an + // even multiple of a byte. + if ((DestBitWidth & 7) == 0 && + (cast(V->getType())->getBitWidth() & 7) == 0) + if (Constant *Res = ExtractConstantBytes(V, 0, DestBitWidth / 8)) + return Res; + return 0; + } case Instruction::BitCast: return FoldBitCast(Context, V, DestTy); - default: - assert(!"Invalid CE CastInst opcode"); - break; } - - llvm_unreachable("Failed to cast constant expression"); - return 0; } Constant *llvm::ConstantFoldSelectInstruction(LLVMContext&, diff --git a/test/Transforms/ConstProp/constant-expr.ll b/test/Transforms/ConstProp/constant-expr.ll index 89c1d926bd2..eece37fa69d 100644 --- a/test/Transforms/ConstProp/constant-expr.ll +++ b/test/Transforms/ConstProp/constant-expr.ll @@ -40,3 +40,21 @@ @O = global i1 icmp eq (i32 zext (i1 icmp ult (i8* @X, i8* @Y) to i32), i32 0) ; CHECK: @O = global i1 icmp uge (i8* @X, i8* @Y) + + +; PR5176 + +; CHECK: @T1 = global i1 true +@T1 = global i1 icmp eq (i64 and (i64 trunc (i256 lshr (i256 or (i256 and (i256 and (i256 shl (i256 zext (i64 ptrtoint (i1* @B to i64) to i256), i256 64), i256 -6277101735386680763495507056286727952638980837032266301441), i256 6277101735386680763835789423207666416102355444464034512895), i256 shl (i256 zext (i64 ptrtoint (i1* @A to i64) to i256), i256 192)), i256 64) to i64), i64 1), i64 0) + +; CHECK: @T2 = global i1* @B +@T2 = global i1* inttoptr (i64 add (i64 trunc (i256 lshr (i256 or (i256 and (i256 and (i256 shl (i256 zext (i64 ptrtoint (i1* @A to i64) to i256), i256 64), i256 -6277101735386680763495507056286727952638980837032266301441), i256 6277101735386680763835789423207666416102355444464034512895), i256 shl (i256 zext (i64 ptrtoint (i1* @B to i64) to i256), i256 192)), i256 192) to i64), i64 trunc (i256 lshr (i256 or (i256 and (i256 and (i256 shl (i256 zext (i64 ptrtoint (i1* @A to i64) to i256), i256 64), i256 -6277101735386680763495507056286727952638980837032266301441), i256 6277101735386680763835789423207666416102355444464034512895), i256 shl (i256 zext (i64 ptrtoint (i1* @B to i64) to i256), i256 192)), i256 128) to i64)) to i1*) + +; CHECK: @T3 = global i64 add (i64 ptrtoint (i1* @B to i64), i64 -1) +@T3 = global i64 add (i64 trunc (i256 lshr (i256 or (i256 and (i256 and (i256 shl (i256 zext (i64 ptrtoint (i1* @B to i64) to i256), i256 64), i256 -6277101735386680763495507056286727952638980837032266301441), i256 6277101735386680763835789423207666416102355444464034512895), i256 shl (i256 zext (i64 ptrtoint (i1* @A to i64) to i256), i256 192)), i256 64) to i64), i64 -1) + +; CHECK: @T4 = global i1* @B +@T4 = global i1* inttoptr (i64 trunc (i256 lshr (i256 or (i256 and (i256 and (i256 shl (i256 zext (i64 ptrtoint (i1* @B to i64) to i256), i256 64), i256 -6277101735386680763495507056286727952638980837032266301441), i256 6277101735386680763835789423207666416102355444464034512895), i256 shl (i256 zext (i64 ptrtoint (i1* @A to i64) to i256), i256 192)), i256 64) to i64) to i1*) + +; CHECK: @T5 = global i1* @A +@T5 = global i1* inttoptr (i64 add (i64 trunc (i256 lshr (i256 or (i256 and (i256 and (i256 shl (i256 zext (i64 ptrtoint (i1* @B to i64) to i256), i256 64), i256 -6277101735386680763495507056286727952638980837032266301441), i256 6277101735386680763835789423207666416102355444464034512895), i256 shl (i256 zext (i64 ptrtoint (i1* @A to i64) to i256), i256 192)), i256 192) to i64), i64 trunc (i256 lshr (i256 or (i256 and (i256 and (i256 shl (i256 zext (i64 ptrtoint (i1* @B to i64) to i256), i256 64), i256 -6277101735386680763495507056286727952638980837032266301441), i256 6277101735386680763835789423207666416102355444464034512895), i256 shl (i256 zext (i64 ptrtoint (i1* @A to i64) to i256), i256 192)), i256 128) to i64)) to i1*) \ No newline at end of file