diff --git a/lib/Transforms/Utils/CodeExtractor.cpp b/lib/Transforms/Utils/CodeExtractor.cpp index b132b563031..c5494aab33b 100644 --- a/lib/Transforms/Utils/CodeExtractor.cpp +++ b/lib/Transforms/Utils/CodeExtractor.cpp @@ -17,6 +17,7 @@ #include "llvm/Constants.h" #include "llvm/DerivedTypes.h" #include "llvm/Instructions.h" +#include "llvm/Intrinsics.h" #include "llvm/Module.h" #include "llvm/Pass.h" #include "llvm/Analysis/Dominators.h" @@ -55,9 +56,28 @@ namespace { bool isEligible(const std::vector &code); private: - void findInputsOutputs(Values &inputs, Values &outputs, - BasicBlock *newHeader, - BasicBlock *newRootNode); + /// definedInRegion - Return true if the specified value is defined in the + /// extracted region. + bool definedInRegion(Value *V) const { + if (Instruction *I = dyn_cast(V)) + if (BlocksToExtract.count(I->getParent())) + return true; + return false; + } + + /// definedInCaller - Return true if the specified value is defined in the + /// function being code extracted, but not in the region being extracted. + /// These values must be passed in as live-ins to the function. + bool definedInCaller(Value *V) const { + if (isa(V)) return true; + if (Instruction *I = dyn_cast(V)) + if (!BlocksToExtract.count(I->getParent())) + return true; + return false; + } + + void severSplitPHINodes(BasicBlock *&Header); + void findInputsOutputs(Values &inputs, Values &outputs); Function *constructFunction(const Values &inputs, const Values &outputs, @@ -75,51 +95,40 @@ namespace { }; } -void CodeExtractor::findInputsOutputs(Values &inputs, Values &outputs, - BasicBlock *newHeader, - BasicBlock *newRootNode) { +/// severSplitPHINodes - If a PHI node has multiple inputs from outside of the +/// region, we need to split the entry block of the region so that the PHI node +/// is easier to deal with. +void CodeExtractor::severSplitPHINodes(BasicBlock *&Header) { + + + +} + +// findInputsOutputs - Find inputs to, outputs from the code region. +// +void CodeExtractor::findInputsOutputs(Values &inputs, Values &outputs) { std::set ExitBlocks; for (std::set::const_iterator ci = BlocksToExtract.begin(), ce = BlocksToExtract.end(); ci != ce; ++ci) { BasicBlock *BB = *ci; + for (BasicBlock::iterator I = BB->begin(), E = BB->end(); I != E; ++I) { // If a used value is defined outside the region, it's an input. If an // instruction is used outside the region, it's an output. - if (PHINode *PN = dyn_cast(I)) { - for (unsigned i = 0, e = PN->getNumIncomingValues(); i != e; ++i) { - Value *V = PN->getIncomingValue(i); - if (!BlocksToExtract.count(PN->getIncomingBlock(i)) && - (isa(V) || isa(V))) - inputs.push_back(V); - else if (Instruction *opI = dyn_cast(V)) { - if (!BlocksToExtract.count(opI->getParent())) - inputs.push_back(opI); - } else if (isa(V)) - inputs.push_back(V); - } - } else { - // All other instructions go through the generic input finder - // Loop over the operands of each instruction (inputs) - for (User::op_iterator op = I->op_begin(), opE = I->op_end(); - op != opE; ++op) - if (Instruction *opI = dyn_cast(*op)) { - // Check if definition of this operand is within the loop - if (!BlocksToExtract.count(opI->getParent())) - inputs.push_back(opI); - } else if (isa(*op)) { - inputs.push_back(*op); - } - } + for (User::op_iterator O = I->op_begin(), E = I->op_end(); O != E; ++O) + if (definedInCaller(*O)) + inputs.push_back(*O); - // Consider uses of this instruction (outputs) + // Consider uses of this instruction (outputs). for (Value::use_iterator UI = I->use_begin(), E = I->use_end(); UI != E; ++UI) - if (!BlocksToExtract.count(cast(*UI)->getParent())) { + if (!definedInRegion(*UI)) { outputs.push_back(I); break; } } // for: insts + // Keep track of the exit blocks from the region. TerminatorInst *TI = BB->getTerminator(); for (unsigned i = 0, e = TI->getNumSuccessors(); i != e; ++i) if (!BlocksToExtract.count(TI->getSuccessor(i))) @@ -238,30 +247,14 @@ Function *CodeExtractor::constructFunction(const Values &inputs, return newFunction; } -void CodeExtractor::moveCodeToFunction(Function *newFunction) { - Function *oldFunc = (*BlocksToExtract.begin())->getParent(); - Function::BasicBlockListType &oldBlocks = oldFunc->getBasicBlockList(); - Function::BasicBlockListType &newBlocks = newFunction->getBasicBlockList(); - - for (std::set::const_iterator i = BlocksToExtract.begin(), - e = BlocksToExtract.end(); i != e; ++i) { - // Delete the basic block from the old function, and the list of blocks - oldBlocks.remove(*i); - - // Insert this basic block into the new function - newBlocks.push_back(*i); - } -} - -void -CodeExtractor::emitCallAndSwitchStatement(Function *newFunction, - BasicBlock *codeReplacer, - Values &inputs, - Values &outputs) { - - // Emit a call to the new function, passing in: - // *pointer to struct (if aggregating parameters), or - // plan inputs and allocated memory for outputs +/// emitCallAndSwitchStatement - This method sets up the caller side by adding +/// the call instruction, splitting any PHI nodes in the header block as +/// necessary. +void CodeExtractor:: +emitCallAndSwitchStatement(Function *newFunction, BasicBlock *codeReplacer, + Values &inputs, Values &outputs) { + // Emit a call to the new function, passing in: *pointer to struct (if + // aggregating parameters), or plan inputs and allocated memory for outputs std::vector params, StructValues, ReloadOutputs; // Add inputs as params, or to be filled into the struct @@ -462,6 +455,20 @@ CodeExtractor::emitCallAndSwitchStatement(Function *newFunction, } } +void CodeExtractor::moveCodeToFunction(Function *newFunction) { + Function *oldFunc = (*BlocksToExtract.begin())->getParent(); + Function::BasicBlockListType &oldBlocks = oldFunc->getBasicBlockList(); + Function::BasicBlockListType &newBlocks = newFunction->getBasicBlockList(); + + for (std::set::const_iterator i = BlocksToExtract.begin(), + e = BlocksToExtract.end(); i != e; ++i) { + // Delete the basic block from the old function, and the list of blocks + oldBlocks.remove(*i); + + // Insert this basic block into the new function + newBlocks.push_back(*i); + } +} /// ExtractRegion - Removes a loop from a function, replaces it with a call to /// new function. Returns pointer to the new function. @@ -497,6 +504,10 @@ Function *CodeExtractor::ExtractCodeRegion(const std::vector &code) // Assumption: this is a single-entry code region, and the header is the first // block in the region. BasicBlock *header = code[0]; + + // If we have to split PHI nodes, do so now. + severSplitPHINodes(header); + for (unsigned i = 1, e = code.size(); i != e; ++i) for (pred_iterator PI = pred_begin(code[i]), E = pred_end(code[i]); PI != E; ++PI) @@ -510,29 +521,14 @@ Function *CodeExtractor::ExtractCodeRegion(const std::vector &code) BasicBlock *codeReplacer = new BasicBlock("codeRepl", oldFunction); // The new function needs a root node because other nodes can branch to the - // head of the loop, and the root cannot have predecessors + // head of the region, but the entry node of a function cannot have preds. BasicBlock *newFuncRoot = new BasicBlock("newFuncRoot"); newFuncRoot->getInstList().push_back(new BranchInst(header)); - // Find inputs to, outputs from the code region - // - // If one of the inputs is coming from a different basic block and it's in a - // phi node, we need to rewrite the phi node: - // - // * All the inputs which involve basic blocks OUTSIDE of this region go into - // a NEW phi node that takes care of finding which value really came in. - // The result of this phi is passed to the function as an argument. - // - // * All the other phi values stay. - // - // FIXME: PHI nodes' incoming blocks aren't being rewritten to accomodate for - // blocks moving to a new function. - // SOLUTION: move Phi nodes out of the loop header into the codeReplacer, pass - // the values as parameters to the function - findInputsOutputs(inputs, outputs, codeReplacer, newFuncRoot); + // Find inputs to, outputs from the code region. + findInputsOutputs(inputs, outputs); - // Step 2: Construct new function based on inputs/outputs, - // Add allocas for all defs + // Construct new function based on inputs/outputs & add allocas for all defs. Function *newFunction = constructFunction(inputs, outputs, code[0], newFuncRoot, codeReplacer, oldFunction, @@ -568,13 +564,17 @@ Function *CodeExtractor::ExtractCodeRegion(const std::vector &code) } bool CodeExtractor::isEligible(const std::vector &code) { - // Deny code region if it contains allocas + // Deny code region if it contains allocas or vastarts. for (std::vector::const_iterator BB = code.begin(), e=code.end(); BB != e; ++BB) for (BasicBlock::const_iterator I = (*BB)->begin(), Ie = (*BB)->end(); I != Ie; ++I) if (isa(*I)) return false; + else if (const CallInst *CI = dyn_cast(I)) + if (const Function *F = CI->getCalledFunction()) + if (F->getIntrinsicID() == Intrinsic::vastart) + return false; return true; }