From be04fdeb6c46e92fdeda7535c5912d072eff008c Mon Sep 17 00:00:00 2001 From: Nick Lewycky Date: Sun, 8 Aug 2010 05:04:23 +0000 Subject: [PATCH] Do more to modernize MergeFunctions. Refactor in response to Chris' code review. git-svn-id: https://llvm.org/svn/llvm-project/llvm/trunk@110538 91177308-0d34-0410-b5e6-96231b3b80d8 --- lib/Transforms/IPO/MergeFunctions.cpp | 172 ++++++++++++-------------- 1 file changed, 81 insertions(+), 91 deletions(-) diff --git a/lib/Transforms/IPO/MergeFunctions.cpp b/lib/Transforms/IPO/MergeFunctions.cpp index 21e9d6ad424..0b36204fc57 100644 --- a/lib/Transforms/IPO/MergeFunctions.cpp +++ b/lib/Transforms/IPO/MergeFunctions.cpp @@ -29,18 +29,17 @@ // // Many functions have their address taken by the virtual function table for // the object they belong to. However, as long as it's only used for a lookup -// and call, this is irrelevant, and we'd like to fold such implementations. +// and call, this is irrelevant, and we'd like to fold such functions. // // * switch from n^2 pair-wise comparisons to an n-way comparison for each // bucket. // -// * be smarter about bitcast. +// * be smarter about bitcasts. // // In order to fold functions, we will sometimes add either bitcast instructions // or bitcast constant expressions. Unfortunately, this can confound further // analysis since the two functions differ where one has a bitcast and the -// other doesn't. We should learn to peer through bitcasts without imposing bad -// performance properties. +// other doesn't. We should learn to look through bitcasts. // //===----------------------------------------------------------------------===// @@ -59,6 +58,7 @@ #include "llvm/Support/CallSite.h" #include "llvm/Support/Debug.h" #include "llvm/Support/ErrorHandling.h" +#include "llvm/Support/IRBuilder.h" #include "llvm/Support/raw_ostream.h" #include "llvm/Target/TargetData.h" #include @@ -73,11 +73,28 @@ namespace { /// MergeFunctions will fold them by replacing a call to one to a call to a /// bitcast of the other. /// - struct MergeFunctions : public ModulePass { - static char ID; // Pass identification, replacement for typeid + class MergeFunctions : public ModulePass { + public: + static char ID; MergeFunctions() : ModulePass(ID) {} bool runOnModule(Module &M); + + private: + /// PairwiseCompareAndMerge - Given a list of functions, compare each pair + /// and merge the pairs of equivalent functions. + bool PairwiseCompareAndMerge(std::vector &FnVec); + + /// MergeTwoFunctions - Merge two equivalent functions. Upon completion, + /// FnVec[j] should never be visited again. + void MergeTwoFunctions(std::vector &FnVec, + unsigned i, unsigned j) const; + + /// WriteThunk - Replace G with a simple tail call to bitcast(F). Also + /// replace direct uses of G with bitcast(F). + void WriteThunk(Function *F, Function *G) const; + + TargetData *TD; }; } @@ -88,42 +105,42 @@ ModulePass *llvm::createMergeFunctionsPass() { return new MergeFunctions(); } -// ===----------------------------------------------------------------------=== -// Comparison of functions -// ===----------------------------------------------------------------------=== namespace { +/// FunctionComparator - Compares two functions to determine whether or not +/// they will generate machine code with the same behaviour. TargetData is +/// used if available. The comparator always fails conservatively (erring on the +/// side of claiming that two functions are different). class FunctionComparator { public: FunctionComparator(TargetData *TD, Function *F1, Function *F2) : F1(F1), F2(F2), TD(TD), IDMap1Count(0), IDMap2Count(0) {} - // Compare - test whether the two functions have equivalent behaviour. + /// Compare - test whether the two functions have equivalent behaviour. bool Compare(); private: - // Compare - test whether two basic blocks have equivalent behaviour. + /// Compare - test whether two basic blocks have equivalent behaviour. bool Compare(const BasicBlock *BB1, const BasicBlock *BB2); - // Enumerate - Assign or look up previously assigned numbers for the two - // values, and return whether the numbers are equal. Numbers are assigned in - // the order visited. + /// Enumerate - Assign or look up previously assigned numbers for the two + /// values, and return whether the numbers are equal. Numbers are assigned in + /// the order visited. bool Enumerate(const Value *V1, const Value *V2); - // isEquivalentOperation - Compare two Instructions for equivalence, similar - // to Instruction::isSameOperationAs but with modifications to the type - // comparison. + /// isEquivalentOperation - Compare two Instructions for equivalence, similar + /// to Instruction::isSameOperationAs but with modifications to the type + /// comparison. bool isEquivalentOperation(const Instruction *I1, const Instruction *I2) const; - // isEquivalentGEP - Compare two GEPs for equivalent pointer arithmetic. + /// isEquivalentGEP - Compare two GEPs for equivalent pointer arithmetic. bool isEquivalentGEP(const GEPOperator *GEP1, const GEPOperator *GEP2); - bool isEquivalentGEP(const GetElementPtrInst *GEP1, - const GetElementPtrInst *GEP2) { + const GetElementPtrInst *GEP2) { return isEquivalentGEP(cast(GEP1), cast(GEP2)); } - // isEquivalentType - Compare two Types, treating all pointer types as equal. + /// isEquivalentType - Compare two Types, treating all pointer types as equal. bool isEquivalentType(const Type *Ty1, const Type *Ty2) const; // The two functions undergoing comparison. @@ -137,9 +154,8 @@ private: }; } -/// Compute a number which is guaranteed to be equal for two equivalent -/// functions, but is very likely to be different for different functions. This -/// needs to be computed as efficiently as possible. +/// Compute a hash guaranteed to be equal for two equivalent functions, but +/// very likely to be different for different functions. static unsigned long ProfileFunction(const Function *F) { const FunctionType *FTy = F->getFunctionType(); @@ -208,7 +224,6 @@ bool FunctionComparator::isEquivalentType(const Type *Ty1, const UnionType *UTy1 = cast(Ty1); const UnionType *UTy2 = cast(Ty2); - // TODO: we could be fancy with union(A, union(A, B)) === union(A, B), etc. if (UTy1->getNumElements() != UTy2->getNumElements()) return false; @@ -373,7 +388,7 @@ bool FunctionComparator::Enumerate(const Value *V1, const Value *V2) { return ID1 == ID2; } -// Compare - test whether two basic blocks have equivalent behaviour. +/// Compare - test whether two basic blocks have equivalent behaviour. bool FunctionComparator::Compare(const BasicBlock *BB1, const BasicBlock *BB2) { BasicBlock::const_iterator F1I = BB1->begin(), F1E = BB1->end(); BasicBlock::const_iterator F2I = BB2->begin(), F2E = BB2->end(); @@ -416,6 +431,7 @@ bool FunctionComparator::Compare(const BasicBlock *BB1, const BasicBlock *BB2) { return F1I == F1E && F2I == F2E; } +/// Compare - test whether the two functions have equivalent behaviour. bool FunctionComparator::Compare() { // We need to recheck everything, but check the things that weren't included // in the hash first. @@ -457,8 +473,8 @@ bool FunctionComparator::Compare() { llvm_unreachable("Arguments repeat"); } - // We need to do an ordered walk since the actual ordering of the blocks in - // the linked list is immaterial. Our walk starts at the entry block for both + // We do a CFG-ordered walk since the actual ordering of the blocks in the + // linked list is immaterial. Our walk starts at the entry block for both // functions, then takes each block from each terminator in order. As an // artifact, this also means that unreachable blocks are ignored. SmallVector F1BBs, F2BBs; @@ -490,31 +506,9 @@ bool FunctionComparator::Compare() { return true; } -// ===----------------------------------------------------------------------=== -// Folding of functions -// ===----------------------------------------------------------------------=== - -// Cases: -// * F is external strong, G is external strong: -// turn G into a thunk to F -// * F is external strong, G is external weak: -// turn G into a thunk to F -// * F is external weak, G is external weak: -// unfoldable -// * F is external strong, G is internal: -// turn G into a thunk to F -// * F is internal, G is external weak -// turn G into a thunk to F -// * F is internal, G is internal: -// turn G into a thunk to F -// -// external means 'externally visible' linkage != (internal,private) -// internal means linkage == (internal,private) -// weak means linkage mayBeOverridable - -/// ThunkGToF - Replace G with a simple tail call to bitcast(F). Also replace +/// WriteThunk - Replace G with a simple tail call to bitcast(F). Also replace /// direct uses of G with bitcast(F). -static void ThunkGToF(Function *F, Function *G) { +void MergeFunctions::WriteThunk(Function *F, Function *G) const { if (!G->mayBeOverridden()) { // Redirect direct callers of G to F. Constant *BitcastF = ConstantExpr::getBitCast(F, G->getType()); @@ -538,30 +532,24 @@ static void ThunkGToF(Function *F, Function *G) { Function *NewG = Function::Create(G->getFunctionType(), G->getLinkage(), "", G->getParent()); BasicBlock *BB = BasicBlock::Create(F->getContext(), "", NewG); + IRBuilder Builder(BB); SmallVector Args; unsigned i = 0; const FunctionType *FFTy = F->getFunctionType(); for (Function::arg_iterator AI = NewG->arg_begin(), AE = NewG->arg_end(); AI != AE; ++AI) { - if (FFTy->getParamType(i) == AI->getType()) { - Args.push_back(AI); - } else { - Args.push_back(new BitCastInst(AI, FFTy->getParamType(i), "", BB)); - } + Args.push_back(Builder.CreateBitCast(AI, FFTy->getParamType(i))); ++i; } - CallInst *CI = CallInst::Create(F, Args.begin(), Args.end(), "", BB); + CallInst *CI = Builder.CreateCall(F, Args.begin(), Args.end()); CI->setTailCall(); CI->setCallingConv(F->getCallingConv()); if (NewG->getReturnType()->isVoidTy()) { - ReturnInst::Create(F->getContext(), BB); - } else if (CI->getType() != NewG->getReturnType()) { - Value *BCI = new BitCastInst(CI, NewG->getReturnType(), "", BB); - ReturnInst::Create(F->getContext(), BCI, BB); + Builder.CreateRetVoid(); } else { - ReturnInst::Create(F->getContext(), CI, BB); + Builder.CreateRet(Builder.CreateBitCast(CI, NewG->getReturnType())); } NewG->copyAttributesFrom(G); @@ -570,7 +558,10 @@ static void ThunkGToF(Function *F, Function *G) { G->eraseFromParent(); } -static bool fold(std::vector &FnVec, unsigned i, unsigned j) { +/// MergeTwoFunctions - Merge two equivalent functions. Upon completion, +/// FnVec[j] should never be visited again. +void MergeFunctions::MergeTwoFunctions(std::vector &FnVec, + unsigned i, unsigned j) const { Function *F = FnVec[i]; Function *G = FnVec[j]; @@ -589,22 +580,39 @@ static bool fold(std::vector &FnVec, unsigned i, unsigned j) { H->takeName(F); F->replaceAllUsesWith(H); - ThunkGToF(F, G); - ThunkGToF(F, H); + WriteThunk(F, G); + WriteThunk(F, H); F->setAlignment(std::max(G->getAlignment(), H->getAlignment())); F->setLinkage(GlobalValue::InternalLinkage); } else { - ThunkGToF(F, G); + WriteThunk(F, G); } ++NumFunctionsMerged; - return true; } -// ===----------------------------------------------------------------------=== -// Pass definition -// ===----------------------------------------------------------------------=== +/// PairwiseCompareAndMerge - Given a list of functions, compare each pair and +/// merge the pairs of equivalent functions. +bool MergeFunctions::PairwiseCompareAndMerge(std::vector &FnVec) { + bool Changed = false; + for (int i = 0, e = FnVec.size(); i != e; ++i) { + for (int j = i + 1; j != e; ++j) { + bool isEqual = FunctionComparator(TD, FnVec[i], FnVec[j]).Compare(); + + DEBUG(dbgs() << " " << FnVec[i]->getName() + << (isEqual ? " == " : " != ") << FnVec[j]->getName() << "\n"); + + if (isEqual) { + MergeTwoFunctions(FnVec, i, j); + Changed = true; + FnVec.erase(FnVec.begin() + j); + --j, --e; + } + } + } + return Changed; +} bool MergeFunctions::runOnModule(Module &M) { bool Changed = false; @@ -618,7 +626,7 @@ bool MergeFunctions::runOnModule(Module &M) { FnMap[ProfileFunction(F)].push_back(F); } - TargetData *TD = getAnalysisIfAvailable(); + TD = getAnalysisIfAvailable(); bool LocalChanged; do { @@ -628,25 +636,7 @@ bool MergeFunctions::runOnModule(Module &M) { I = FnMap.begin(), E = FnMap.end(); I != E; ++I) { std::vector &FnVec = I->second; DEBUG(dbgs() << "hash (" << I->first << "): " << FnVec.size() << "\n"); - - for (int i = 0, e = FnVec.size(); i != e; ++i) { - for (int j = i + 1; j != e; ++j) { - bool isEqual = FunctionComparator(TD, FnVec[i], FnVec[j]).Compare(); - - DEBUG(dbgs() << " " << FnVec[i]->getName() - << (isEqual ? " == " : " != ") - << FnVec[j]->getName() << "\n"); - - if (isEqual) { - if (fold(FnVec, i, j)) { - LocalChanged = true; - FnVec.erase(FnVec.begin() + j); - --j, --e; - } - } - } - } - + LocalChanged |= PairwiseCompareAndMerge(FnVec); } Changed |= LocalChanged; } while (LocalChanged);