diff --git a/lib/VMCore/ConstantFold.cpp b/lib/VMCore/ConstantFold.cpp index 20de93cce52..d35ace0af15 100644 --- a/lib/VMCore/ConstantFold.cpp +++ b/lib/VMCore/ConstantFold.cpp @@ -444,6 +444,8 @@ Constant *llvm::ConstantFoldBinaryInstruction(unsigned Opcode, return Constant::getNullValue(C1->getType()); return const_cast(C2); // X / undef -> undef case Instruction::Or: // X | undef -> -1 + if (const PackedType *PTy = dyn_cast(C1->getType())) + return ConstantPacked::getAllOnesValue(PTy); return ConstantInt::getAllOnesValue(C1->getType()); case Instruction::LShr: if (isa(C2) && isa(C1)) @@ -496,8 +498,9 @@ Constant *llvm::ConstantFoldBinaryInstruction(unsigned Opcode, return Constant::getNullValue(CI->getType()); // X % 1 == 0 break; case Instruction::And: - if (cast(C2)->isAllOnesValue()) - return const_cast(C1); // X & -1 == X + if (const ConstantInt *CI = dyn_cast(C2)) + if (CI->isAllOnesValue()) + return const_cast(C1); // X & -1 == X if (C2->isNullValue()) return const_cast(C2); // X & 0 == 0 if (CE1->isCast() && isa(CE1->getOperand(0))) { GlobalValue *CPR = cast(CE1->getOperand(0)); @@ -511,8 +514,9 @@ Constant *llvm::ConstantFoldBinaryInstruction(unsigned Opcode, break; case Instruction::Or: if (C2->isNullValue()) return const_cast(C1); // X | 0 == X - if (cast(C2)->isAllOnesValue()) - return const_cast(C2); // X | -1 == -1 + if (const ConstantInt *CI = dyn_cast(C2)) + if (CI->isAllOnesValue()) + return const_cast(C2); // X | -1 == -1 break; case Instruction::Xor: if (C2->isNullValue()) return const_cast(C1); // X ^ 0 == X @@ -547,7 +551,7 @@ Constant *llvm::ConstantFoldBinaryInstruction(unsigned Opcode, } // At this point we know neither constant is an UndefValue nor a ConstantExpr - // so look at directly computing the + // so look at directly computing the value. if (const ConstantBool *CB1 = dyn_cast(C1)) { if (const ConstantBool *CB2 = dyn_cast(C2)) { switch (Opcode) { @@ -606,10 +610,16 @@ Constant *llvm::ConstantFoldBinaryInstruction(unsigned Opcode, case Instruction::Xor: return ConstantInt::get(C1->getType(), C1Val ^ C2Val); case Instruction::Shl: + if (C2Val >= CI1->getType()->getPrimitiveSizeInBits()) + C2Val = CI1->getType()->getPrimitiveSizeInBits() - 1; return ConstantInt::get(C1->getType(), C1Val << C2Val); case Instruction::LShr: + if (C2Val >= CI1->getType()->getPrimitiveSizeInBits()) + C2Val = CI1->getType()->getPrimitiveSizeInBits() - 1; return ConstantInt::get(C1->getType(), C1Val >> C2Val); case Instruction::AShr: + if (C2Val >= CI1->getType()->getPrimitiveSizeInBits()) + C2Val = CI1->getType()->getPrimitiveSizeInBits() - 1; return ConstantInt::get(C1->getType(), CI1->getSExtValue() >> C2Val); }