diff --git a/lib/Transforms/IPO/MergeFunctions.cpp b/lib/Transforms/IPO/MergeFunctions.cpp index 17bc2d41a4c..16083ec7817 100644 --- a/lib/Transforms/IPO/MergeFunctions.cpp +++ b/lib/Transforms/IPO/MergeFunctions.cpp @@ -9,10 +9,6 @@ // // This pass looks for equivalent functions that are mergable and folds them. // -// A Function will not be analyzed if: -// * it is overridable at runtime (except for weak linkage), or -// * it is used by anything other than the callee parameter of a call/invoke -// // A hash is computed from the function, based on its type and number of // basic blocks. // @@ -24,8 +20,6 @@ // When a match is found, the functions are folded. We can only fold two // functions when we know that the definition of one of them is not // overridable. -// * fold a function marked internal by replacing all of its users. -// * fold extern or weak functions by replacing them with a global alias // //===----------------------------------------------------------------------===// // @@ -48,6 +42,7 @@ #define DEBUG_TYPE "mergefunc" #include "llvm/Transforms/IPO.h" #include "llvm/ADT/DenseMap.h" +#include "llvm/ADT/FoldingSet.h" #include "llvm/ADT/Statistic.h" #include "llvm/Constants.h" #include "llvm/InlineAsm.h" @@ -62,7 +57,6 @@ using namespace llvm; STATISTIC(NumFunctionsMerged, "Number of functions merged"); -STATISTIC(NumMergeFails, "Number of identical function pairings not merged"); namespace { struct VISIBILITY_HIDDEN MergeFunctions : public ModulePass { @@ -81,16 +75,168 @@ ModulePass *llvm::createMergeFunctionsPass() { return new MergeFunctions(); } +// ===----------------------------------------------------------------------=== +// Comparison of functions +// ===----------------------------------------------------------------------=== + static unsigned long hash(const Function *F) { - return F->size() ^ reinterpret_cast(F->getType()); - //return F->size() ^ F->arg_size() ^ F->getReturnType(); + const FunctionType *FTy = F->getFunctionType(); + + FoldingSetNodeID ID; + ID.AddInteger(F->size()); + ID.AddInteger(F->getCallingConv()); + ID.AddBoolean(F->hasGC()); + ID.AddBoolean(FTy->isVarArg()); + ID.AddInteger(FTy->getReturnType()->getTypeID()); + for (unsigned i = 0, e = FTy->getNumParams(); i != e; ++i) + ID.AddInteger(FTy->getParamType(i)->getTypeID()); + return ID.ComputeHash(); +} + +/// IgnoreBitcasts - given a bitcast, returns the first non-bitcast found by +/// walking the chain of cast operands. Otherwise, returns the argument. +static Value* IgnoreBitcasts(Value *V) { + while (BitCastInst *BC = dyn_cast(V)) + V = BC->getOperand(0); + + return V; +} + +/// isEquivalentType - any two pointers are equivalent. Otherwise, standard +/// type equivalence rules apply. +static bool isEquivalentType(const Type *Ty1, const Type *Ty2) { + if (Ty1 == Ty2) + return true; + if (Ty1->getTypeID() != Ty2->getTypeID()) + return false; + + switch(Ty1->getTypeID()) { + case Type::VoidTyID: + case Type::FloatTyID: + case Type::DoubleTyID: + case Type::X86_FP80TyID: + case Type::FP128TyID: + case Type::PPC_FP128TyID: + case Type::LabelTyID: + case Type::MetadataTyID: + return true; + + case Type::IntegerTyID: + case Type::OpaqueTyID: + // Ty1 == Ty2 would have returned true earlier. + return false; + + default: + assert(0 && "Unknown type!"); + return false; + + case Type::PointerTyID: { + const PointerType *PTy1 = cast(Ty1); + const PointerType *PTy2 = cast(Ty2); + return PTy1->getAddressSpace() == PTy2->getAddressSpace(); + } + + case Type::StructTyID: { + const StructType *STy1 = cast(Ty1); + const StructType *STy2 = cast(Ty2); + if (STy1->getNumElements() != STy2->getNumElements()) + return false; + + if (STy1->isPacked() != STy2->isPacked()) + return false; + + for (unsigned i = 0, e = STy1->getNumElements(); i != e; ++i) { + if (!isEquivalentType(STy1->getElementType(i), STy2->getElementType(i))) + return false; + } + return true; + } + + case Type::FunctionTyID: { + const FunctionType *FTy1 = cast(Ty1); + const FunctionType *FTy2 = cast(Ty2); + if (FTy1->getNumParams() != FTy2->getNumParams() || + FTy1->isVarArg() != FTy2->isVarArg()) + return false; + + if (!isEquivalentType(FTy1->getReturnType(), FTy2->getReturnType())) + return false; + + for (unsigned i = 0, e = FTy1->getNumParams(); i != e; ++i) { + if (!isEquivalentType(FTy1->getParamType(i), FTy2->getParamType(i))) + return false; + } + return true; + } + + case Type::ArrayTyID: + case Type::VectorTyID: { + const SequentialType *STy1 = cast(Ty1); + const SequentialType *STy2 = cast(Ty2); + return isEquivalentType(STy1->getElementType(), STy2->getElementType()); + } + } +} + +/// isEquivalentOperation - determine whether the two operations are the same +/// except that pointer-to-A and pointer-to-B are equivalent. This should be +/// kept in sync with Instruction::isSameOperandAs. +static bool isEquivalentOperation(const Instruction *I1, const Instruction *I2) { + if (I1->getOpcode() != I2->getOpcode() || + I1->getNumOperands() != I2->getNumOperands() || + !isEquivalentType(I1->getType(), I2->getType())) + return false; + + // We have two instructions of identical opcode and #operands. Check to see + // if all operands are the same type + for (unsigned i = 0, e = I1->getNumOperands(); i != e; ++i) + if (!isEquivalentType(I1->getOperand(i)->getType(), + I2->getOperand(i)->getType())) + return false; + + // Check special state that is a part of some instructions. + if (const LoadInst *LI = dyn_cast(I1)) + return LI->isVolatile() == cast(I2)->isVolatile() && + LI->getAlignment() == cast(I2)->getAlignment(); + if (const StoreInst *SI = dyn_cast(I1)) + return SI->isVolatile() == cast(I2)->isVolatile() && + SI->getAlignment() == cast(I2)->getAlignment(); + if (const CmpInst *CI = dyn_cast(I1)) + return CI->getPredicate() == cast(I2)->getPredicate(); + if (const CallInst *CI = dyn_cast(I1)) + return CI->isTailCall() == cast(I2)->isTailCall() && + CI->getCallingConv() == cast(I2)->getCallingConv() && + CI->getAttributes().getRawPointer() == + cast(I2)->getAttributes().getRawPointer(); + if (const InvokeInst *CI = dyn_cast(I1)) + return CI->getCallingConv() == cast(I2)->getCallingConv() && + CI->getAttributes().getRawPointer() == + cast(I2)->getAttributes().getRawPointer(); + if (const InsertValueInst *IVI = dyn_cast(I1)) { + if (IVI->getNumIndices() != cast(I2)->getNumIndices()) + return false; + for (unsigned i = 0, e = IVI->getNumIndices(); i != e; ++i) + if (IVI->idx_begin()[i] != cast(I2)->idx_begin()[i]) + return false; + return true; + } + if (const ExtractValueInst *EVI = dyn_cast(I1)) { + if (EVI->getNumIndices() != cast(I2)->getNumIndices()) + return false; + for (unsigned i = 0, e = EVI->getNumIndices(); i != e; ++i) + if (EVI->idx_begin()[i] != cast(I2)->idx_begin()[i]) + return false; + return true; + } + + return true; } static bool compare(const Value *V, const Value *U) { assert(!isa(V) && !isa(U) && "Must not compare basic blocks."); - assert(V->getType() == U->getType() && + assert(isEquivalentType(V->getType(), U->getType()) && "Two of the same operation have operands of different type."); // TODO: If the constant is an expression of F, we should accept that it's @@ -117,18 +263,24 @@ static bool compare(const Value *V, const Value *U) { static bool equals(const BasicBlock *BB1, const BasicBlock *BB2, DenseMap &ValueMap, DenseMap &SpeculationMap) { - // Specutively add it anyways. If it's false, we'll notice a difference later, and - // this won't matter. + // Speculatively add it anyways. If it's false, we'll notice a difference + // later, and this won't matter. ValueMap[BB1] = BB2; BasicBlock::const_iterator FI = BB1->begin(), FE = BB1->end(); BasicBlock::const_iterator GI = BB2->begin(), GE = BB2->end(); do { - if (!FI->isSameOperationAs(const_cast(&*GI))) - return false; + if (isa(FI)) { + ++FI; + continue; + } + if (isa(GI)) { + ++GI; + continue; + } - if (FI->getNumOperands() != GI->getNumOperands()) + if (!isEquivalentOperation(FI, GI)) return false; if (ValueMap[FI] == GI) { @@ -140,8 +292,8 @@ static bool equals(const BasicBlock *BB1, const BasicBlock *BB2, return false; for (unsigned i = 0, e = FI->getNumOperands(); i != e; ++i) { - Value *OpF = FI->getOperand(i); - Value *OpG = GI->getOperand(i); + Value *OpF = IgnoreBitcasts(FI->getOperand(i)); + Value *OpG = IgnoreBitcasts(GI->getOperand(i)); if (ValueMap[OpF] == OpG) continue; @@ -149,10 +301,8 @@ static bool equals(const BasicBlock *BB1, const BasicBlock *BB2, if (ValueMap[OpF] != NULL) return false; - assert(OpF->getType() == OpG->getType() && - "Two of the same operation has operands of different type."); - - if (OpF->getValueID() != OpG->getValueID()) + if (OpF->getValueID() != OpG->getValueID() || + !isEquivalentType(OpF->getType(), OpG->getType())) return false; if (isa(FI)) { @@ -203,14 +353,15 @@ static bool equals(const Function *F, const Function *G) { if (F->hasSection() && F->getSection() != G->getSection()) return false; + if (F->isVarArg() != G->isVarArg()) + return false; + // TODO: if it's internal and only used in direct calls, we could handle this // case too. if (F->getCallingConv() != G->getCallingConv()) return false; - // TODO: We want to permit cases where two functions take T* and S* but - // only load or store them into T** and S**. - if (F->getType() != G->getType()) + if (!isEquivalentType(F->getFunctionType(), G->getFunctionType())) return false; DenseMap ValueMap; @@ -237,88 +388,198 @@ static bool equals(const Function *F, const Function *G) { return true; } -static bool fold(std::vector &FnVec, unsigned i, unsigned j) { - if (FnVec[i]->mayBeOverridden() && !FnVec[j]->mayBeOverridden()) - std::swap(FnVec[i], FnVec[j]); +// ===----------------------------------------------------------------------=== +// Folding of functions +// ===----------------------------------------------------------------------=== +// Cases: +// * F is external strong, G is external strong: +// turn G into a thunk to F (1) +// * F is external strong, G is external weak: +// turn G into a thunk to F (1) +// * F is external weak, G is external weak: +// unfoldable +// * F is external strong, G is internal: +// address of G taken: +// turn G into a thunk to F (1) +// address of G not taken: +// make G an alias to F (2) +// * F is internal, G is external weak +// address of F is taken: +// turn G into a thunk to F (1) +// address of F is not taken: +// make G an alias of F (2) +// * F is internal, G is internal: +// address of F and G are taken: +// turn G into a thunk to F (1) +// address of G is not taken: +// make G an alias to F (2) +// +// alias requires linkage == (external,local,weak) fallback to creating a thunk +// external means 'externally visible' linkage != (internal,private) +// internal means linkage == (internal,private) +// weak means linkage mayBeOverridable +// being external implies that the address is taken +// +// 1. turn G into a thunk to F +// 2. make G an alias to F + +enum LinkageCategory { + ExternalStrong, + ExternalWeak, + Internal +}; + +static LinkageCategory categorize(const Function *F) { + switch (F->getLinkage()) { + case GlobalValue::InternalLinkage: + case GlobalValue::PrivateLinkage: + return Internal; + + case GlobalValue::WeakAnyLinkage: + case GlobalValue::WeakODRLinkage: + case GlobalValue::ExternalWeakLinkage: + return ExternalWeak; + + case GlobalValue::ExternalLinkage: + case GlobalValue::AvailableExternallyLinkage: + case GlobalValue::LinkOnceAnyLinkage: + case GlobalValue::LinkOnceODRLinkage: + case GlobalValue::AppendingLinkage: + case GlobalValue::DLLImportLinkage: + case GlobalValue::DLLExportLinkage: + case GlobalValue::GhostLinkage: + case GlobalValue::CommonLinkage: + return ExternalStrong; + } + + assert(0 && "Unknown LinkageType."); + return ExternalWeak; +} + +static void ThunkGToF(Function *F, Function *G) { + Function *NewG = Function::Create(G->getFunctionType(), G->getLinkage(), + "", G->getParent()); + BasicBlock *BB = BasicBlock::Create("", NewG); + + std::vector Args; + unsigned i = 0; + const FunctionType *FFTy = F->getFunctionType(); + for (Function::arg_iterator AI = NewG->arg_begin(), AE = NewG->arg_end(); + AI != AE; ++AI) { + if (FFTy->getParamType(i) == AI->getType()) + Args.push_back(AI); + else { + Value *BCI = new BitCastInst(AI, FFTy->getParamType(i), "", BB); + Args.push_back(BCI); + } + ++i; + } + + CallInst *CI = CallInst::Create(F, Args.begin(), Args.end(), "", BB); + CI->setTailCall(); + if (NewG->getReturnType() == Type::VoidTy) { + ReturnInst::Create(BB); + } else if (CI->getType() != NewG->getReturnType()) { + Value *BCI = new BitCastInst(CI, NewG->getReturnType(), "", BB); + ReturnInst::Create(BCI, BB); + } else { + ReturnInst::Create(CI, BB); + } + + NewG->copyAttributesFrom(G); + NewG->takeName(G); + G->replaceAllUsesWith(NewG); + G->eraseFromParent(); + + // TODO: look at direct callers to G and make them all direct callers to F + // iff G->hasAddressTaken() is false. +} + +static void AliasGToF(Function *F, Function *G) { + if (!G->hasExternalLinkage() && !G->hasLocalLinkage() && !G->hasWeakLinkage()) + return ThunkGToF(F, G); + + GlobalAlias *GA = new GlobalAlias( + G->getType(), G->getLinkage(), "", + ConstantExpr::getBitCast(F, G->getType()), G->getParent()); + F->setAlignment(std::max(F->getAlignment(), G->getAlignment())); + GA->takeName(G); + GA->setVisibility(G->getVisibility()); + G->replaceAllUsesWith(GA); + G->eraseFromParent(); +} + +static bool fold(std::vector &FnVec, unsigned i, unsigned j) { Function *F = FnVec[i]; Function *G = FnVec[j]; - if (!F->mayBeOverridden()) { - if (G->hasLocalLinkage()) { - F->setAlignment(std::max(F->getAlignment(), G->getAlignment())); - G->replaceAllUsesWith(F); - G->eraseFromParent(); - ++NumFunctionsMerged; - return true; - } + LinkageCategory catF = categorize(F); + LinkageCategory catG = categorize(G); - if (G->hasExternalLinkage() || G->hasWeakLinkage()) { - GlobalAlias *GA = new GlobalAlias(G->getType(), G->getLinkage(), "", - F, G->getParent()); - F->setAlignment(std::max(F->getAlignment(), G->getAlignment())); - GA->takeName(G); - GA->setVisibility(G->getVisibility()); - G->replaceAllUsesWith(GA); - G->eraseFromParent(); - ++NumFunctionsMerged; - return true; - } + if (catF == ExternalWeak || (catF == Internal && catG == ExternalStrong)) { + std::swap(FnVec[i], FnVec[j]); + std::swap(F, G); + std::swap(catF, catG); } - if (F->hasWeakLinkage() && G->hasWeakLinkage()) { - GlobalAlias *GA_F = new GlobalAlias(F->getType(), F->getLinkage(), "", - 0, F->getParent()); - GA_F->takeName(F); - GA_F->setVisibility(F->getVisibility()); - F->setAlignment(std::max(F->getAlignment(), G->getAlignment())); - F->replaceAllUsesWith(GA_F); - F->setName("folded." + GA_F->getName()); - F->setLinkage(GlobalValue::ExternalLinkage); - GA_F->setAliasee(F); + switch (catF) { + case ExternalStrong: + switch (catG) { + case ExternalStrong: + case ExternalWeak: + ThunkGToF(F, G); + break; + case Internal: + if (G->hasAddressTaken()) + ThunkGToF(F, G); + else + AliasGToF(F, G); + break; + } + break; - GlobalAlias *GA_G = new GlobalAlias(G->getType(), G->getLinkage(), "", - F, G->getParent()); - GA_G->takeName(G); - GA_G->setVisibility(G->getVisibility()); - G->replaceAllUsesWith(GA_G); - G->eraseFromParent(); + case ExternalWeak: + return false; - ++NumFunctionsMerged; - return true; + case Internal: + switch (catG) { + case ExternalStrong: + assert(0); + // fall-through + case ExternalWeak: + if (F->hasAddressTaken()) + ThunkGToF(F, G); + else + AliasGToF(F, G); + break; + case Internal: { + bool addrTakenF = F->hasAddressTaken(); + bool addrTakenG = G->hasAddressTaken(); + if (!addrTakenF && addrTakenG) { + std::swap(FnVec[i], FnVec[j]); + std::swap(F, G); + std::swap(addrTakenF, addrTakenG); + } + + if (addrTakenF && addrTakenG) { + ThunkGToF(F, G); + } else { + assert(!addrTakenG); + AliasGToF(F, G); + } + } break; + } + break; } - DOUT << "Failed on " << F->getName() << " and " << G->getName() << "\n"; - - ++NumMergeFails; - return false; + ++NumFunctionsMerged; + return true; } -static bool hasAddressTaken(User *U) { - for (User::use_iterator I = U->use_begin(), E = U->use_end(); I != E; ++I) { - User *Use = *I; - - // 'call (bitcast @F to ...)' happens a lot. - while (isa(Use) && Use->hasOneUse()) { - Use = *Use->use_begin(); - } - - if (isa(Use)) { - if (hasAddressTaken(Use)) - return true; - } - - if (!isa(Use) && !isa(Use)) - return true; - - // Make sure we aren't passing U as a parameter to call instead of the - // callee. - if (CallSite(cast(Use)).hasArgument(U)) - return true; - } - - return false; -} +// ===----------------------------------------------------------------------=== +// Pass definition +// ===----------------------------------------------------------------------=== bool MergeFunctions::runOnModule(Module &M) { bool Changed = false; @@ -329,25 +590,19 @@ bool MergeFunctions::runOnModule(Module &M) { if (F->isDeclaration() || F->isIntrinsic()) continue; - if (!F->hasLocalLinkage() && !F->hasExternalLinkage() && - !F->hasWeakLinkage()) - continue; - - if (hasAddressTaken(F)) - continue; - FnMap[hash(F)].push_back(F); } - // TODO: instead of running in a loop, we could also fold functions in callgraph - // order. Constructing the CFG probably isn't cheaper than just running in a loop. + // TODO: instead of running in a loop, we could also fold functions in + // callgraph order. Constructing the CFG probably isn't cheaper than just + // running in a loop, unless it happened to already be available. bool LocalChanged; do { LocalChanged = false; + DOUT << "size: " << FnMap.size() << "\n"; for (std::map >::iterator I = FnMap.begin(), E = FnMap.end(); I != E; ++I) { - DOUT << "size: " << FnMap.size() << "\n"; std::vector &FnVec = I->second; DOUT << "hash (" << I->first << "): " << FnVec.size() << "\n";