diff --git a/lib/Transforms/Utils/Local.cpp b/lib/Transforms/Utils/Local.cpp index f0da3e554c0..3d1faff0184 100644 --- a/lib/Transforms/Utils/Local.cpp +++ b/lib/Transforms/Utils/Local.cpp @@ -64,18 +64,6 @@ Constant *llvm::ConstantFoldInstruction(Instruction *I) { // If we reach here, all incoming values are the same constant. return Result; - } else if (CallInst *CI = dyn_cast(I)) { - if (Function *F = CI->getCalledFunction()) - if (canConstantFoldCallTo(F)) { - std::vector Args; - for (unsigned i = 1, e = CI->getNumOperands(); i != e; ++i) - if (Constant *Op = dyn_cast(CI->getOperand(i))) - Args.push_back(Op); - else - return 0; - return ConstantFoldCall(F, Args); - } - return 0; } Constant *Op0 = 0, *Op1 = 0; @@ -91,37 +79,63 @@ Constant *llvm::ConstantFoldInstruction(Instruction *I) { case 0: return 0; } - if (isa(I) || isa(I)) - return ConstantExpr::get(I->getOpcode(), Op0, Op1); + if (isa(I) || isa(I)) { + if (Constant *Op0 = dyn_cast(I->getOperand(0))) + if (Constant *Op1 = dyn_cast(I->getOperand(1))) + return ConstantExpr::get(I->getOpcode(), Op0, Op1); + return 0; // Operands not constants. + } - switch (I->getOpcode()) { + // Scan the operand list, checking to see if the are all constants, if so, + // hand off to ConstantFoldInstOperands. + std::vector Ops; + for (unsigned i = 0, e = I->getNumOperands(); i != e; ++i) + if (Constant *Op = dyn_cast(I->getOperand(i))) + Ops.push_back(Op); + else + return 0; // All operands not constant! + + return ConstantFoldInstOperands(I->getOpcode(), I->getType(), Ops); +} + +/// ConstantFoldInstOperands - Attempt to constant fold an instruction with the +/// specified opcode and operands. If successful, the constant result is +/// returned, if not, null is returned. Note that this function can fail when +/// attempting to fold instructions like loads and stores, which have no +/// constant expression form. +/// +Constant *llvm::ConstantFoldInstOperands(unsigned Opc, const Type *DestTy, + const std::vector &Ops) { + if (Opc >= Instruction::BinaryOpsBegin && Opc < Instruction::BinaryOpsEnd) + return ConstantExpr::get(Opc, Ops[0], Ops[1]); + + switch (Opc) { default: return 0; + case Instruction::Call: + if (Function *F = dyn_cast(Ops[0])) { + if (canConstantFoldCallTo(F)) { + std::vector Args(Ops.begin()+1, Ops.end()); + return ConstantFoldCall(F, Args); + } + } + return 0; + case Instruction::Shl: + case Instruction::Shr: + return ConstantExpr::get(Opc, Ops[0], Ops[1]); case Instruction::Cast: - return ConstantExpr::getCast(Op0, I->getType()); + return ConstantExpr::getCast(Ops[0], DestTy); case Instruction::Select: - if (Constant *Op2 = dyn_cast(I->getOperand(2))) - return ConstantExpr::getSelect(Op0, Op1, Op2); - return 0; + return ConstantExpr::getSelect(Ops[0], Ops[1], Ops[2]); case Instruction::ExtractElement: - return ConstantExpr::getExtractElement(Op0, Op1); + return ConstantExpr::getExtractElement(Ops[0], Ops[1]); case Instruction::InsertElement: - if (Constant *Op2 = dyn_cast(I->getOperand(2))) - return ConstantExpr::getInsertElement(Op0, Op1, Op2); - return 0; + return ConstantExpr::getInsertElement(Ops[0], Ops[1], Ops[2]); case Instruction::ShuffleVector: - if (Constant *Op2 = dyn_cast(I->getOperand(2))) - return ConstantExpr::getShuffleVector(Op0, Op1, Op2); - return 0; + return ConstantExpr::getShuffleVector(Ops[0], Ops[1], Ops[2]); case Instruction::GetElementPtr: - std::vector IdxList; - IdxList.reserve(I->getNumOperands()-1); - if (Op1) IdxList.push_back(Op1); - for (unsigned i = 2, e = I->getNumOperands(); i != e; ++i) - if (Constant *C = dyn_cast(I->getOperand(i))) - IdxList.push_back(C); - else - return 0; // Non-constant operand - return ConstantExpr::getGetElementPtr(Op0, IdxList); + return ConstantExpr::getGetElementPtr(Ops[0], + std::vector(Ops.begin()+1, + Ops.end())); } }