diff --git a/lib/Transforms/IPO/MergeFunctions.cpp b/lib/Transforms/IPO/MergeFunctions.cpp index 5d838f98aa0..726466ec81d 100644 --- a/lib/Transforms/IPO/MergeFunctions.cpp +++ b/lib/Transforms/IPO/MergeFunctions.cpp @@ -69,32 +69,79 @@ using namespace llvm; STATISTIC(NumFunctionsMerged, "Number of functions merged"); namespace { - /// MergeFunctions finds functions which will generate identical machine code, - /// by considering all pointer types to be equivalent. Once identified, - /// MergeFunctions will fold them by replacing a call to one to a call to a - /// bitcast of the other. - /// - class MergeFunctions : public ModulePass { - public: - static char ID; - MergeFunctions() : ModulePass(ID) {} - bool runOnModule(Module &M); +static unsigned ProfileFunction(const Function *F) { + const FunctionType *FTy = F->getFunctionType(); - private: - /// MergeTwoFunctions - Merge two equivalent functions. Upon completion, G - /// may be deleted, or may be converted into a thunk. In either case, it - /// should never be visited again. - void MergeTwoFunctions(Function *F, Function *G) const; - - /// WriteThunk - Replace G with a simple tail call to bitcast(F). Also - /// replace direct uses of G with bitcast(F). - void WriteThunk(Function *F, Function *G) const; - - TargetData *TD; - }; + FoldingSetNodeID ID; + ID.AddInteger(F->size()); + ID.AddInteger(F->getCallingConv()); + ID.AddBoolean(F->hasGC()); + ID.AddBoolean(FTy->isVarArg()); + ID.AddInteger(FTy->getReturnType()->getTypeID()); + for (unsigned i = 0, e = FTy->getNumParams(); i != e; ++i) + ID.AddInteger(FTy->getParamType(i)->getTypeID()); + return ID.ComputeHash(); } +class ComparableFunction { +public: + ComparableFunction(Function *Func, TargetData *TD) + : Func(Func), Hash(ProfileFunction(Func)), TD(TD) {} + + AssertingVH const Func; + const unsigned Hash; + TargetData * const TD; +}; + +struct MergeFunctionsEqualityInfo { + static ComparableFunction *getEmptyKey() { + return reinterpret_cast(0); + } + static ComparableFunction *getTombstoneKey() { + return reinterpret_cast(-1); + } + static unsigned getHashValue(const ComparableFunction *CF) { + return CF->Hash; + } + static bool isEqual(const ComparableFunction *LHS, + const ComparableFunction *RHS); +}; + +/// MergeFunctions finds functions which will generate identical machine code, +/// by considering all pointer types to be equivalent. Once identified, +/// MergeFunctions will fold them by replacing a call to one to a call to a +/// bitcast of the other. +/// +class MergeFunctions : public ModulePass { +public: + static char ID; + MergeFunctions() : ModulePass(ID) {} + + bool runOnModule(Module &M); + +private: + typedef DenseSet FnSetType; + + + /// Insert a ComparableFunction into the FnSet, or merge it away if it's + /// equal to one that's already present. + bool Insert(FnSetType &FnSet, ComparableFunction *NewF); + + /// MergeTwoFunctions - Merge two equivalent functions. Upon completion, G + /// may be deleted, or may be converted into a thunk. In either case, it + /// should never be visited again. + void MergeTwoFunctions(Function *F, Function *G) const; + + /// WriteThunk - Replace G with a simple tail call to bitcast(F). Also + /// replace direct uses of G with bitcast(F). Deletes G. + void WriteThunk(Function *F, Function *G) const; + + TargetData *TD; +}; + +} // end anonymous namespace + char MergeFunctions::ID = 0; INITIALIZE_PASS(MergeFunctions, "mergefunc", "Merge Functions", false, false); @@ -475,7 +522,7 @@ bool FunctionComparator::Compare() { } /// WriteThunk - Replace G with a simple tail call to bitcast(F). Also replace -/// direct uses of G with bitcast(F). +/// direct uses of G with bitcast(F). Deletes G. void MergeFunctions::WriteThunk(Function *F, Function *G) const { if (!G->mayBeOverridden()) { // Redirect direct callers of G to F. @@ -553,100 +600,138 @@ void MergeFunctions::MergeTwoFunctions(Function *F, Function *G) const { ++NumFunctionsMerged; } -static unsigned ProfileFunction(const Function *F) { - const FunctionType *FTy = F->getFunctionType(); +// Insert - Insert a ComparableFunction into the FnSet, or merge it away if +// equal to one that's already inserted. +bool MergeFunctions::Insert(FnSetType &FnSet, ComparableFunction *NewF) { + std::pair Result = FnSet.insert(NewF); + if (Result.second) + return false; - FoldingSetNodeID ID; - ID.AddInteger(F->size()); - ID.AddInteger(F->getCallingConv()); - ID.AddBoolean(F->hasGC()); - ID.AddBoolean(FTy->isVarArg()); - ID.AddInteger(FTy->getReturnType()->getTypeID()); - for (unsigned i = 0, e = FTy->getNumParams(); i != e; ++i) - ID.AddInteger(FTy->getParamType(i)->getTypeID()); - return ID.ComputeHash(); + ComparableFunction *OldF = *Result.first; + assert(OldF && "Expected a hash collision"); + + // Never thunk a strong function to a weak function. + assert(!OldF->Func->isWeakForLinker() || NewF->Func->isWeakForLinker()); + + DEBUG(dbgs() << " " << OldF->Func->getName() << " == " + << NewF->Func->getName() << '\n'); + + Function *DeleteF = NewF->Func; + delete NewF; + MergeTwoFunctions(OldF->Func, DeleteF); + return true; } -class ComparableFunction { -public: - ComparableFunction(Function *Func, TargetData *TD) - : Func(Func), Hash(ProfileFunction(Func)), TD(TD) {} +// IsThunk - This method determines whether or not a given Function is a thunk\// like the ones emitted by this pass and therefore not subject to further +// merging. +static bool IsThunk(const Function *F) { + // The safe direction to fail is to return true. In that case, the function + // will be removed from merging analysis. If we failed to including functions + // then we may try to merge unmergable thing (ie., identical weak functions) + // which will push us into an infinite loop. - AssertingVH const Func; - const unsigned Hash; - TargetData * const TD; -}; + if (F->size() != 1) + return false; -struct MergeFunctionsEqualityInfo { - static ComparableFunction *getEmptyKey() { - return reinterpret_cast(0); + const BasicBlock *BB = &F->front(); + // A thunk is: + // bitcast-inst* + // optional-reg tail call @thunkee(args...*) + // ret void|optional-reg + // where the args are in the same order as the arguments. + + // Verify that the sequence of bitcast-inst's are all casts of arguments and + // that there aren't any extras (ie. no repeated casts). + int LastArgNo = -1; + BasicBlock::const_iterator I = BB->begin(); + while (const BitCastInst *BCI = dyn_cast(I)) { + const Argument *A = dyn_cast(BCI->getOperand(0)); + if (!A) return false; + if ((int)A->getArgNo() >= LastArgNo) return false; + LastArgNo = A->getArgNo(); + ++I; } - static ComparableFunction *getTombstoneKey() { - return reinterpret_cast(-1); + + // Verify that the call instruction has the same arguments as this function + // and that they're all either the incoming argument or a cast of the right + // argument. + const CallInst *CI = dyn_cast(I++); + if (!CI || !CI->isTailCall() || + CI->getNumArgOperands() != F->arg_size()) return false; + + for (unsigned i = 0, e = CI->getNumArgOperands(); i != e; ++i) { + const Value *V = CI->getArgOperand(i); + const Argument *A = dyn_cast(V); + if (!A) { + const BitCastInst *BCI = dyn_cast(V); + if (!BCI) return false; + A = cast(BCI->getOperand(0)); + } + if (A->getArgNo() != i) return false; } - static unsigned getHashValue(const ComparableFunction *CF) { - return CF->Hash; + + // Verify that the terminator is a ret void (if we're void) or a ret of the + // call's return, or a ret of a bitcast of the call's return. + const Value *RetOp = CI; + if (const BitCastInst *BCI = dyn_cast(I)) { + ++I; + if (BCI->getOperand(0) != CI) return false; + RetOp = BCI; } - static bool isEqual(const ComparableFunction *LHS, - const ComparableFunction *RHS) { - if (LHS == RHS) - return true; - if (LHS == getEmptyKey() || LHS == getTombstoneKey() || - RHS == getEmptyKey() || RHS == getTombstoneKey()) - return false; - assert(LHS->TD == RHS->TD && "Comparing functions for different targets"); - return FunctionComparator(LHS->TD, LHS->Func, RHS->Func).Compare(); - } -}; + const ReturnInst *RI = dyn_cast(I); + if (!RI) return false; + if (RI->getNumOperands() == 0) + return CI->getType()->isVoidTy(); + return RI->getReturnValue() == CI; +} bool MergeFunctions::runOnModule(Module &M) { - typedef DenseSet FnSetType; - bool Changed = false; TD = getAnalysisIfAvailable(); - std::vector Funcs; - for (Module::iterator F = M.begin(), E = M.end(); F != E; ++F) { - if (!F->isDeclaration() && !F->hasAvailableExternallyLinkage()) - Funcs.push_back(F); - } - bool LocalChanged; do { + DEBUG(dbgs() << "size: " << M.size() << '\n'); LocalChanged = false; - FnSetType FnSet; - for (unsigned i = 0, e = Funcs.size(); i != e;) { - Function *F = Funcs[i]; - ComparableFunction *NewF = new ComparableFunction(F, TD); - std::pair Result = FnSet.insert(NewF); - if (!Result.second) { - ComparableFunction *&OldF = *Result.first; - assert(OldF && "Expected a hash collision"); - // NewF will be deleted in favour of OldF unless NewF is strong and - // OldF is weak in which case swap them to keep the strong definition. + // Insert only strong functions and merge them. Strong function merging + // always deletes one of them. + for (Module::iterator I = M.begin(), E = M.end(); I != E;) { + Function *F = I++; + if (!F->isDeclaration() && !F->hasAvailableExternallyLinkage() && + !F->isWeakForLinker() && !IsThunk(F)) { + ComparableFunction *CF = new ComparableFunction(F, TD); + LocalChanged |= Insert(FnSet, CF); + } + } - if (OldF->Func->isWeakForLinker() && !NewF->Func->isWeakForLinker()) - std::swap(OldF, NewF); - - DEBUG(dbgs() << " " << OldF->Func->getName() << " == " - << NewF->Func->getName() << '\n'); - - Funcs.erase(Funcs.begin() + i); - --e; - - Function *DeleteF = NewF->Func; - delete NewF; - MergeTwoFunctions(OldF->Func, DeleteF); - LocalChanged = true; - Changed = true; - } else { - ++i; + // Insert only weak functions and merge them. By doing these second we + // create thunks to the strong function when possible. When two weak + // functions are identical, we create a new strong function with two weak + // weak thunks to it which are identical but not mergable. + for (Module::iterator I = M.begin(), E = M.end(); I != E;) { + Function *F = I++; + if (!F->isDeclaration() && !F->hasAvailableExternallyLinkage() && + F->isWeakForLinker() && !IsThunk(F)) { + ComparableFunction *CF = new ComparableFunction(F, TD); + LocalChanged |= Insert(FnSet, CF); } } DeleteContainerPointers(FnSet); + Changed |= LocalChanged; } while (LocalChanged); return Changed; } + +bool MergeFunctionsEqualityInfo::isEqual(const ComparableFunction *LHS, + const ComparableFunction *RHS) { + if (LHS == RHS) + return true; + if (LHS == getEmptyKey() || LHS == getTombstoneKey() || + RHS == getEmptyKey() || RHS == getTombstoneKey()) + return false; + assert(LHS->TD == RHS->TD && "Comparing functions for different targets"); + return FunctionComparator(LHS->TD, LHS->Func, RHS->Func).Compare(); +}