Simplify the code a bit by making the collection of basic blocks to extract

a member of the class.  While we're at it, turn the collection into a set
instead of a vector to improve efficiency and make queries simpler.


git-svn-id: https://llvm.org/svn/llvm-project/llvm/trunk@12400 91177308-0d34-0410-b5e6-96231b3b80d8
This commit is contained in:
Chris Lattner 2004-03-14 22:34:55 +00:00
parent 3b04a7a131
commit 0e06674287

View File

@ -26,16 +26,11 @@
#include "Support/Debug.h"
#include "Support/StringExtras.h"
#include <algorithm>
#include <map>
#include <vector>
#include <set>
using namespace llvm;
namespace {
inline bool contains(const std::vector<BasicBlock*> &V, const BasicBlock *BB){
return std::find(V.begin(), V.end(), BB) != V.end();
}
/// getFunctionArg - Return a pointer to F's ARGNOth argument.
///
Argument *getFunctionArg(Function *F, unsigned argno) {
@ -49,19 +44,16 @@ namespace {
typedef std::vector<std::pair<unsigned, unsigned> > PhiValChangesTy;
typedef std::map<PHINode*, PhiValChangesTy> PhiVal2ArgTy;
PhiVal2ArgTy PhiVal2Arg;
std::set<BasicBlock*> BlocksToExtract;
public:
Function *ExtractCodeRegion(const std::vector<BasicBlock*> &code);
private:
void findInputsOutputs(const std::vector<BasicBlock*> &code,
Values &inputs,
Values &outputs,
void findInputsOutputs(Values &inputs, Values &outputs,
BasicBlock *newHeader,
BasicBlock *newRootNode);
void processPhiNodeInputs(PHINode *Phi,
const std::vector<BasicBlock*> &code,
Values &inputs,
BasicBlock *newHeader,
BasicBlock *newRootNode);
@ -71,15 +63,12 @@ namespace {
Function *constructFunction(const Values &inputs,
const Values &outputs,
BasicBlock *newRootNode, BasicBlock *newHeader,
const std::vector<BasicBlock*> &code,
Function *oldFunction, Module *M);
void moveCodeToFunction(const std::vector<BasicBlock*> &code,
Function *newFunction);
void moveCodeToFunction(Function *newFunction);
void emitCallAndSwitchStatement(Function *newFunction,
BasicBlock *newHeader,
const std::vector<BasicBlock*> &code,
Values &inputs,
Values &outputs);
@ -87,7 +76,6 @@ namespace {
}
void CodeExtractor::processPhiNodeInputs(PHINode *Phi,
const std::vector<BasicBlock*> &code,
Values &inputs,
BasicBlock *codeReplacer,
BasicBlock *newFuncRoot)
@ -102,11 +90,11 @@ void CodeExtractor::processPhiNodeInputs(PHINode *Phi,
for (unsigned i = 0, e = Phi->getNumIncomingValues(); i != e; ++i) {
Value *phiVal = Phi->getIncomingValue(i);
if (Instruction *Inst = dyn_cast<Instruction>(phiVal)) {
if (contains(code, Inst->getParent())) {
if (!contains(code, Phi->getIncomingBlock(i)))
if (BlocksToExtract.count(Inst->getParent())) {
if (!BlocksToExtract.count(Phi->getIncomingBlock(i)))
IValEBB.push_back(i);
} else {
if (contains(code, Phi->getIncomingBlock(i)))
if (BlocksToExtract.count(Phi->getIncomingBlock(i)))
EValIBB.push_back(i);
else
EValEBB.push_back(i);
@ -114,11 +102,11 @@ void CodeExtractor::processPhiNodeInputs(PHINode *Phi,
} else if (Constant *Const = dyn_cast<Constant>(phiVal)) {
// Constants are internal, but considered `external' if they are coming
// from an external block.
if (!contains(code, Phi->getIncomingBlock(i)))
if (!BlocksToExtract.count(Phi->getIncomingBlock(i)))
EValEBB.push_back(i);
} else if (Argument *Arg = dyn_cast<Argument>(phiVal)) {
// arguments are external
if (contains(code, Phi->getIncomingBlock(i)))
if (BlocksToExtract.count(Phi->getIncomingBlock(i)))
EValIBB.push_back(i);
else
EValEBB.push_back(i);
@ -184,14 +172,13 @@ void CodeExtractor::processPhiNodeInputs(PHINode *Phi,
}
void CodeExtractor::findInputsOutputs(const std::vector<BasicBlock*> &code,
Values &inputs,
void CodeExtractor::findInputsOutputs(Values &inputs,
Values &outputs,
BasicBlock *newHeader,
BasicBlock *newRootNode)
{
for (std::vector<BasicBlock*>::const_iterator ci = code.begin(),
ce = code.end(); ci != ce; ++ci) {
for (std::set<BasicBlock*>::const_iterator ci = BlocksToExtract.begin(),
ce = BlocksToExtract.end(); ci != ce; ++ci) {
BasicBlock *BB = *ci;
for (BasicBlock::iterator BBi = BB->begin(), BBe = BB->end();
BBi != BBe; ++BBi) {
@ -200,7 +187,7 @@ void CodeExtractor::findInputsOutputs(const std::vector<BasicBlock*> &code,
if (Instruction *I = dyn_cast<Instruction>(&*BBi)) {
// If it's a phi node
if (PHINode *Phi = dyn_cast<PHINode>(I)) {
processPhiNodeInputs(Phi, code, inputs, newHeader, newRootNode);
processPhiNodeInputs(Phi, inputs, newHeader, newRootNode);
} else {
// All other instructions go through the generic input finder
// Loop over the operands of each instruction (inputs)
@ -208,7 +195,7 @@ void CodeExtractor::findInputsOutputs(const std::vector<BasicBlock*> &code,
op != opE; ++op) {
if (Instruction *opI = dyn_cast<Instruction>(op->get())) {
// Check if definition of this operand is within the loop
if (!contains(code, opI->getParent())) {
if (!BlocksToExtract.count(opI->getParent())) {
// add this operand to the inputs
inputs.push_back(opI);
}
@ -220,7 +207,7 @@ void CodeExtractor::findInputsOutputs(const std::vector<BasicBlock*> &code,
for (Value::use_iterator use = I->use_begin(), useE = I->use_end();
use != useE; ++use) {
if (Instruction* inst = dyn_cast<Instruction>(*use)) {
if (!contains(code, inst->getParent())) {
if (!BlocksToExtract.count(inst->getParent())) {
// add this op to the outputs
outputs.push_back(I);
}
@ -276,11 +263,10 @@ Function *CodeExtractor::constructFunction(const Values &inputs,
const Values &outputs,
BasicBlock *newRootNode,
BasicBlock *newHeader,
const std::vector<BasicBlock*> &code,
Function *oldFunction, Module *M) {
DEBUG(std::cerr << "inputs: " << inputs.size() << "\n");
DEBUG(std::cerr << "outputs: " << outputs.size() << "\n");
BasicBlock *header = code[0];
BasicBlock *header = *BlocksToExtract.begin();
// This function returns unsigned, outputs will go back by reference.
Type *retTy = Type::UShortTy;
@ -327,7 +313,7 @@ Function *CodeExtractor::constructFunction(const Values &inputs,
for (std::vector<User*>::iterator use = Users.begin(), useE = Users.end();
use != useE; ++use)
if (Instruction* inst = dyn_cast<Instruction>(*use))
if (contains(code, inst->getParent()))
if (BlocksToExtract.count(inst->getParent()))
inst->replaceUsesOfWith(inputs[i], getFunctionArg(newFunction, i));
}
@ -339,7 +325,7 @@ Function *CodeExtractor::constructFunction(const Values &inputs,
i != e; ++i) {
if (BranchInst *inst = dyn_cast<BranchInst>(*i)) {
BasicBlock *BB = inst->getParent();
if (!contains(code, BB) && BB->getParent() == oldFunction) {
if (!BlocksToExtract.count(BB) && BB->getParent() == oldFunction) {
// The BasicBlock which contains the branch is not in the region
// modify the branch target to a new block
inst->replaceUsesOfWith(header, newHeader);
@ -350,29 +336,25 @@ Function *CodeExtractor::constructFunction(const Values &inputs,
return newFunction;
}
void CodeExtractor::moveCodeToFunction(const std::vector<BasicBlock*> &code,
Function *newFunction)
void CodeExtractor::moveCodeToFunction(Function *newFunction)
{
Function *oldFunc = code[0]->getParent();
Function *oldFunc = (*BlocksToExtract.begin())->getParent();
Function::BasicBlockListType &oldBlocks = oldFunc->getBasicBlockList();
Function::BasicBlockListType &newBlocks = newFunction->getBasicBlockList();
for (std::vector<BasicBlock*>::const_iterator i = code.begin(), e =code.end();
i != e; ++i) {
BasicBlock *BB = *i;
for (std::set<BasicBlock*>::const_iterator i = BlocksToExtract.begin(),
e = BlocksToExtract.end(); i != e; ++i) {
// Delete the basic block from the old function, and the list of blocks
oldBlocks.remove(BB);
oldBlocks.remove(*i);
// Insert this basic block into the new function
newBlocks.push_back(BB);
newBlocks.push_back(*i);
}
}
void
CodeExtractor::emitCallAndSwitchStatement(Function *newFunction,
BasicBlock *codeReplacer,
const std::vector<BasicBlock*> &code,
Values &inputs,
Values &outputs)
{
@ -399,7 +381,7 @@ CodeExtractor::emitCallAndSwitchStatement(Function *newFunction,
for (std::vector<User*>::iterator use = Users.begin(), useE =Users.end();
use != useE; ++use) {
if (Instruction* inst = dyn_cast<Instruction>(*use)) {
if (!contains(code, inst->getParent())) {
if (!BlocksToExtract.count(inst->getParent())) {
inst->replaceUsesOfWith(*i, load);
}
}
@ -425,8 +407,8 @@ CodeExtractor::emitCallAndSwitchStatement(Function *newFunction,
// Since there may be multiple exits from the original region, make the new
// function return an unsigned, switch on that number
unsigned switchVal = 0;
for (std::vector<BasicBlock*>::const_iterator i =code.begin(), e = code.end();
i != e; ++i) {
for (std::set<BasicBlock*>::const_iterator i = BlocksToExtract.begin(),
e = BlocksToExtract.end(); i != e; ++i) {
BasicBlock *BB = *i;
// rewrite the terminator of the original BasicBlock
@ -436,16 +418,14 @@ CodeExtractor::emitCallAndSwitchStatement(Function *newFunction,
// Restore values just before we exit
// FIXME: Use a GetElementPtr to bunch the outputs in a struct
for (unsigned outIdx = 0, outE = outputs.size(); outIdx != outE; ++outIdx)
{
new StoreInst(outputs[outIdx],
getFunctionArg(newFunction, outIdx),
brInst);
}
// Rewrite branches into exits which return a value based on which
// exit we take from this function
if (brInst->isUnconditional()) {
if (!contains(code, brInst->getSuccessor(0))) {
if (!BlocksToExtract.count(brInst->getSuccessor(0))) {
ConstantUInt *brVal = ConstantUInt::get(Type::UShortTy, switchVal);
ReturnInst *newRet = new ReturnInst(brVal);
// add a new target to the switch
@ -461,7 +441,7 @@ CodeExtractor::emitCallAndSwitchStatement(Function *newFunction,
// to two new blocks, each of which returns a different code.
for (unsigned idx = 0; idx < 2; ++idx) {
BasicBlock *oldTarget = brInst->getSuccessor(idx);
if (!contains(code, oldTarget)) {
if (!BlocksToExtract.count(oldTarget)) {
// add a new basic block which returns the appropriate value
BasicBlock *newTarget = new BasicBlock("newTarget", newFunction);
ConstantUInt *brVal = ConstantUInt::get(Type::UShortTy, switchVal);
@ -475,13 +455,15 @@ CodeExtractor::emitCallAndSwitchStatement(Function *newFunction,
}
}
}
} else if (SwitchInst *swTerm = dyn_cast<SwitchInst>(term)) {
assert(0 && "Cannot handle switch instructions just yet.");
} else if (ReturnInst *retTerm = dyn_cast<ReturnInst>(term)) {
assert(0 && "Cannot handle return instructions just yet.");
// FIXME: what if the terminator is a return!??!
// Need to rewrite: add new basic block, move the return there
// treat the original as an unconditional branch to that basicblock
} else if (SwitchInst *swTerm = dyn_cast<SwitchInst>(term)) {
assert(0 && "Cannot handle switch instructions just yet.");
} else if (InvokeInst *invInst = dyn_cast<InvokeInst>(term)) {
assert(0 && "Cannot handle invoke instructions just yet.");
} else {
@ -514,7 +496,8 @@ Function *CodeExtractor::ExtractCodeRegion(const std::vector<BasicBlock*> &code)
// * Add allocas for defs, pass as args by reference
// * Pass in uses as args
// 3) Move code region, add call instr to func
//
//
BlocksToExtract.insert(code.begin(), code.end());
Values inputs, outputs;
@ -548,19 +531,18 @@ Function *CodeExtractor::ExtractCodeRegion(const std::vector<BasicBlock*> &code)
// blocks moving to a new function.
// SOLUTION: move Phi nodes out of the loop header into the codeReplacer, pass
// the values as parameters to the function
findInputsOutputs(code, inputs, outputs, codeReplacer, newFuncRoot);
findInputsOutputs(inputs, outputs, codeReplacer, newFuncRoot);
// Step 2: Construct new function based on inputs/outputs,
// Add allocas for all defs
Function *newFunction = constructFunction(inputs, outputs, newFuncRoot,
codeReplacer, code,
oldFunction, module);
codeReplacer, oldFunction, module);
rewritePhiNodes(newFunction, newFuncRoot);
emitCallAndSwitchStatement(newFunction, codeReplacer, code, inputs, outputs);
emitCallAndSwitchStatement(newFunction, codeReplacer, inputs, outputs);
moveCodeToFunction(code, newFunction);
moveCodeToFunction(newFunction);
DEBUG(if (verifyFunction(*newFunction)) abort());
return newFunction;