diff --git a/lib/VMCore/ConstantFold.cpp b/lib/VMCore/ConstantFold.cpp index 8aeca7b47a5..dcd8657bd44 100644 --- a/lib/VMCore/ConstantFold.cpp +++ b/lib/VMCore/ConstantFold.cpp @@ -554,71 +554,55 @@ Constant *llvm::ConstantFoldBinaryInstruction(unsigned Opcode, // so look at directly computing the value. if (const ConstantInt *CI1 = dyn_cast(C1)) { if (const ConstantInt *CI2 = dyn_cast(C2)) { - if (CI1->getType() == Type::Int1Ty && CI2->getType() == Type::Int1Ty) { - switch (Opcode) { - default: - break; - case Instruction::And: - return ConstantInt::get(Type::Int1Ty, - CI1->getZExtValue() & CI2->getZExtValue()); - case Instruction::Or: - return ConstantInt::get(Type::Int1Ty, - CI1->getZExtValue() | CI2->getZExtValue()); - case Instruction::Xor: - return ConstantInt::get(Type::Int1Ty, - CI1->getZExtValue() ^ CI2->getZExtValue()); - } - } else { - uint64_t C1Val = CI1->getZExtValue(); - uint64_t C2Val = CI2->getZExtValue(); - switch (Opcode) { - default: - break; - case Instruction::Add: - return ConstantInt::get(C1->getType(), C1Val + C2Val); - case Instruction::Sub: - return ConstantInt::get(C1->getType(), C1Val - C2Val); - case Instruction::Mul: - return ConstantInt::get(C1->getType(), C1Val * C2Val); - case Instruction::UDiv: - if (CI2->isNullValue()) // X / 0 -> can't fold - return 0; - return ConstantInt::get(C1->getType(), C1Val / C2Val); - case Instruction::SDiv: - if (CI2->isNullValue()) return 0; // X / 0 -> can't fold - if (CI2->isAllOnesValue() && - (((CI1->getType()->getPrimitiveSizeInBits() == 64) && - (CI1->getSExtValue() == INT64_MIN)) || - (CI1->getSExtValue() == -CI1->getSExtValue()))) - return 0; // MIN_INT / -1 -> overflow - return ConstantInt::get(C1->getType(), - CI1->getSExtValue() / CI2->getSExtValue()); - case Instruction::URem: - if (C2->isNullValue()) return 0; // X / 0 -> can't fold - return ConstantInt::get(C1->getType(), C1Val % C2Val); - case Instruction::SRem: - if (CI2->isNullValue()) return 0; // X % 0 -> can't fold - if (CI2->isAllOnesValue() && - (((CI1->getType()->getPrimitiveSizeInBits() == 64) && - (CI1->getSExtValue() == INT64_MIN)) || - (CI1->getSExtValue() == -CI1->getSExtValue()))) - return 0; // MIN_INT % -1 -> overflow - return ConstantInt::get(C1->getType(), - CI1->getSExtValue() % CI2->getSExtValue()); - case Instruction::And: - return ConstantInt::get(C1->getType(), C1Val & C2Val); - case Instruction::Or: - return ConstantInt::get(C1->getType(), C1Val | C2Val); - case Instruction::Xor: - return ConstantInt::get(C1->getType(), C1Val ^ C2Val); - case Instruction::Shl: - return ConstantInt::get(C1->getType(), C1Val << C2Val); - case Instruction::LShr: - return ConstantInt::get(C1->getType(), C1Val >> C2Val); - case Instruction::AShr: - return ConstantInt::get(C1->getType(), - CI1->getSExtValue() >> C2Val); - } + uint64_t C1Val = CI1->getZExtValue(); + uint64_t C2Val = CI2->getZExtValue(); + switch (Opcode) { + default: + break; + case Instruction::Add: + return ConstantInt::get(C1->getType(), C1Val + C2Val); + case Instruction::Sub: + return ConstantInt::get(C1->getType(), C1Val - C2Val); + case Instruction::Mul: + return ConstantInt::get(C1->getType(), C1Val * C2Val); + case Instruction::UDiv: + if (CI2->isNullValue()) // X / 0 -> can't fold + return 0; + return ConstantInt::get(C1->getType(), C1Val / C2Val); + case Instruction::SDiv: + if (CI2->isNullValue()) return 0; // X / 0 -> can't fold + if (CI2->isAllOnesValue() && + (((CI1->getType()->getPrimitiveSizeInBits() == 64) && + (CI1->getSExtValue() == INT64_MIN)) || + (CI1->getSExtValue() == -CI1->getSExtValue()))) + return 0; // MIN_INT / -1 -> overflow + return ConstantInt::get(C1->getType(), + CI1->getSExtValue() / CI2->getSExtValue()); + case Instruction::URem: + if (C2->isNullValue()) return 0; // X / 0 -> can't fold + return ConstantInt::get(C1->getType(), C1Val % C2Val); + case Instruction::SRem: + if (CI2->isNullValue()) return 0; // X % 0 -> can't fold + if (CI2->isAllOnesValue() && + (((CI1->getType()->getPrimitiveSizeInBits() == 64) && + (CI1->getSExtValue() == INT64_MIN)) || + (CI1->getSExtValue() == -CI1->getSExtValue()))) + return 0; // MIN_INT % -1 -> overflow + return ConstantInt::get(C1->getType(), + CI1->getSExtValue() % CI2->getSExtValue()); + case Instruction::And: + return ConstantInt::get(C1->getType(), C1Val & C2Val); + case Instruction::Or: + return ConstantInt::get(C1->getType(), C1Val | C2Val); + case Instruction::Xor: + return ConstantInt::get(C1->getType(), C1Val ^ C2Val); + case Instruction::Shl: + return ConstantInt::get(C1->getType(), C1Val << C2Val); + case Instruction::LShr: + return ConstantInt::get(C1->getType(), C1Val >> C2Val); + case Instruction::AShr: + return ConstantInt::get(C1->getType(), + CI1->getSExtValue() >> C2Val); } } } else if (const ConstantFP *CFP1 = dyn_cast(C1)) { @@ -1059,34 +1043,7 @@ Constant *llvm::ConstantFoldCompareInstruction(unsigned short pred, return ConstantInt::getTrue(); } - if (isa(C1) && isa(C2) && - C1->getType() == Type::Int1Ty && C2->getType() == Type::Int1Ty) { - bool C1Val = cast(C1)->getZExtValue(); - bool C2Val = cast(C2)->getZExtValue(); - switch (pred) { - default: assert(0 && "Invalid ICmp Predicate"); return 0; - case ICmpInst::ICMP_EQ: - return ConstantInt::get(Type::Int1Ty, C1Val == C2Val); - case ICmpInst::ICMP_NE: - return ConstantInt::get(Type::Int1Ty, C1Val != C2Val); - case ICmpInst::ICMP_ULT: - return ConstantInt::get(Type::Int1Ty, C1Val < C2Val); - case ICmpInst::ICMP_UGT: - return ConstantInt::get(Type::Int1Ty, C1Val > C2Val); - case ICmpInst::ICMP_ULE: - return ConstantInt::get(Type::Int1Ty, C1Val <= C2Val); - case ICmpInst::ICMP_UGE: - return ConstantInt::get(Type::Int1Ty, C1Val >= C2Val); - case ICmpInst::ICMP_SLT: - return ConstantInt::get(Type::Int1Ty, C1Val < C2Val); - case ICmpInst::ICMP_SGT: - return ConstantInt::get(Type::Int1Ty, C1Val > C2Val); - case ICmpInst::ICMP_SLE: - return ConstantInt::get(Type::Int1Ty, C1Val <= C2Val); - case ICmpInst::ICMP_SGE: - return ConstantInt::get(Type::Int1Ty, C1Val >= C2Val); - } - } else if (isa(C1) && isa(C2)) { + if (isa(C1) && isa(C2)) { if (ICmpInst::isSignedPredicate(ICmpInst::Predicate(pred))) { int64_t V1 = cast(C1)->getSExtValue(); int64_t V2 = cast(C2)->getSExtValue();