diff --git a/lib/AsmParser/LLParser.cpp b/lib/AsmParser/LLParser.cpp index 0fa53019df1..13b648ce9eb 100644 --- a/lib/AsmParser/LLParser.cpp +++ b/lib/AsmParser/LLParser.cpp @@ -24,6 +24,7 @@ #include "llvm/IR/Operator.h" #include "llvm/IR/ValueSymbolTable.h" #include "llvm/Support/ErrorHandling.h" +#include "llvm/Support/SaveAndRestore.h" #include "llvm/Support/raw_ostream.h" using namespace llvm; @@ -129,28 +130,11 @@ bool LLParser::ValidateEndOfModule() { } } - // If there are entries in ForwardRefBlockAddresses at this point, they are - // references after the function was defined. Resolve those now. - while (!ForwardRefBlockAddresses.empty()) { - // Okay, we are referencing an already-parsed function, resolve them now. - Function *TheFn = nullptr; - const ValID &Fn = ForwardRefBlockAddresses.begin()->first; - if (Fn.Kind == ValID::t_GlobalName) - TheFn = M->getFunction(Fn.StrVal); - else if (Fn.UIntVal < NumberedVals.size()) - TheFn = dyn_cast(NumberedVals[Fn.UIntVal]); - - if (!TheFn) - return Error(Fn.Loc, "unknown function referenced by blockaddress"); - - // Resolve all these references. - if (ResolveForwardRefBlockAddresses(TheFn, - ForwardRefBlockAddresses.begin()->second, - nullptr)) - return true; - - ForwardRefBlockAddresses.erase(ForwardRefBlockAddresses.begin()); - } + // If there are entries in ForwardRefBlockAddresses at this point, the + // function was never defined. + if (!ForwardRefBlockAddresses.empty()) + return Error(ForwardRefBlockAddresses.begin()->first.Loc, + "expected function name in blockaddress"); for (unsigned i = 0, e = NumberedTypes.size(); i != e; ++i) if (NumberedTypes[i].second.isValid()) @@ -193,38 +177,6 @@ bool LLParser::ValidateEndOfModule() { return false; } -bool LLParser::ResolveForwardRefBlockAddresses(Function *TheFn, - std::vector > &Refs, - PerFunctionState *PFS) { - // Loop over all the references, resolving them. - for (unsigned i = 0, e = Refs.size(); i != e; ++i) { - BasicBlock *Res; - if (PFS) { - if (Refs[i].first.Kind == ValID::t_LocalName) - Res = PFS->GetBB(Refs[i].first.StrVal, Refs[i].first.Loc); - else - Res = PFS->GetBB(Refs[i].first.UIntVal, Refs[i].first.Loc); - } else if (Refs[i].first.Kind == ValID::t_LocalID) { - return Error(Refs[i].first.Loc, - "cannot take address of numeric label after the function is defined"); - } else { - Res = dyn_cast_or_null( - TheFn->getValueSymbolTable().lookup(Refs[i].first.StrVal)); - } - - if (!Res) - return Error(Refs[i].first.Loc, - "referenced value is not a basic block"); - - // Get the BlockAddress for this and update references to use it. - BlockAddress *BA = BlockAddress::get(TheFn, Res); - Refs[i].second->replaceAllUsesWith(BA); - Refs[i].second->eraseFromParent(); - } - return false; -} - - //===----------------------------------------------------------------------===// // Top-Level Entities //===----------------------------------------------------------------------===// @@ -2186,28 +2138,6 @@ LLParser::PerFunctionState::~PerFunctionState() { } bool LLParser::PerFunctionState::FinishFunction() { - // Check to see if someone took the address of labels in this block. - if (!P.ForwardRefBlockAddresses.empty()) { - ValID FunctionID; - if (!F.getName().empty()) { - FunctionID.Kind = ValID::t_GlobalName; - FunctionID.StrVal = F.getName(); - } else { - FunctionID.Kind = ValID::t_GlobalID; - FunctionID.UIntVal = FunctionNumber; - } - - std::map > >::iterator - FRBAI = P.ForwardRefBlockAddresses.find(FunctionID); - if (FRBAI != P.ForwardRefBlockAddresses.end()) { - // Resolve all these references. - if (P.ResolveForwardRefBlockAddresses(&F, FRBAI->second, this)) - return true; - - P.ForwardRefBlockAddresses.erase(FRBAI); - } - } - if (!ForwardRefVals.empty()) return P.Error(ForwardRefVals.begin()->second.second, "use of undefined value '%" + ForwardRefVals.begin()->first + @@ -2593,12 +2523,56 @@ bool LLParser::ParseValID(ValID &ID, PerFunctionState *PFS) { if (Label.Kind != ValID::t_LocalID && Label.Kind != ValID::t_LocalName) return Error(Label.Loc, "expected basic block name in blockaddress"); - // Make a global variable as a placeholder for this reference. - GlobalVariable *FwdRef = new GlobalVariable(*M, Type::getInt8Ty(Context), - false, GlobalValue::InternalLinkage, - nullptr, ""); - ForwardRefBlockAddresses[Fn].push_back(std::make_pair(Label, FwdRef)); - ID.ConstantVal = FwdRef; + // Try to find the function (but skip it if it's forward-referenced). + GlobalValue *GV = nullptr; + if (Fn.Kind == ValID::t_GlobalID) { + if (Fn.UIntVal < NumberedVals.size()) + GV = NumberedVals[Fn.UIntVal]; + } else if (!ForwardRefVals.count(Fn.StrVal)) { + GV = M->getNamedValue(Fn.StrVal); + } + Function *F = nullptr; + if (GV) { + // Confirm that it's actually a function with a definition. + if (!isa(GV)) + return Error(Fn.Loc, "expected function name in blockaddress"); + F = cast(GV); + if (F->isDeclaration()) + return Error(Fn.Loc, "cannot take blockaddress inside a declaration"); + } + + if (!F) { + // Make a global variable as a placeholder for this reference. + GlobalValue *&FwdRef = ForwardRefBlockAddresses[Fn][Label]; + if (!FwdRef) + FwdRef = new GlobalVariable(*M, Type::getInt8Ty(Context), false, + GlobalValue::InternalLinkage, nullptr, ""); + ID.ConstantVal = FwdRef; + ID.Kind = ValID::t_Constant; + return false; + } + + // We found the function; now find the basic block. Don't use PFS, since we + // might be inside a constant expression. + BasicBlock *BB; + if (BlockAddressPFS && F == &BlockAddressPFS->getFunction()) { + if (Label.Kind == ValID::t_LocalID) + BB = BlockAddressPFS->GetBB(Label.UIntVal, Label.Loc); + else + BB = BlockAddressPFS->GetBB(Label.StrVal, Label.Loc); + if (!BB) + return Error(Label.Loc, "referenced value is not a basic block"); + } else { + if (Label.Kind == ValID::t_LocalID) + return Error(Label.Loc, "cannot take address of numeric label after " + "the function is defined"); + BB = dyn_cast_or_null( + F->getValueSymbolTable().lookup(Label.StrVal)); + if (!BB) + return Error(Label.Loc, "referenced value is not a basic block"); + } + + ID.ConstantVal = BlockAddress::get(F, BB); ID.Kind = ValID::t_Constant; return false; } @@ -2913,7 +2887,7 @@ bool LLParser::parseOptionalComdat(Comdat *&C) { /// ParseGlobalValueVector /// ::= /*empty*/ /// ::= TypeAndValue (',' TypeAndValue)* -bool LLParser::ParseGlobalValueVector(SmallVectorImpl &Elts) { +bool LLParser::ParseGlobalValueVector(SmallVectorImpl &Elts) { // Empty list. if (Lex.getKind() == lltok::rbrace || Lex.getKind() == lltok::rsquare || @@ -3341,9 +3315,60 @@ bool LLParser::ParseFunctionHeader(Function *&Fn, bool isDefine) { ArgList[i].Name + "'"); } + if (isDefine) + return false; + + // Check the a declaration has no block address forward references. + ValID ID; + if (FunctionName.empty()) { + ID.Kind = ValID::t_GlobalID; + ID.UIntVal = NumberedVals.size() - 1; + } else { + ID.Kind = ValID::t_GlobalName; + ID.StrVal = FunctionName; + } + auto Blocks = ForwardRefBlockAddresses.find(ID); + if (Blocks != ForwardRefBlockAddresses.end()) + return Error(Blocks->first.Loc, + "cannot take blockaddress inside a declaration"); return false; } +bool LLParser::PerFunctionState::resolveForwardRefBlockAddresses() { + ValID ID; + if (FunctionNumber == -1) { + ID.Kind = ValID::t_GlobalName; + ID.StrVal = F.getName(); + } else { + ID.Kind = ValID::t_GlobalID; + ID.UIntVal = FunctionNumber; + } + + auto Blocks = P.ForwardRefBlockAddresses.find(ID); + if (Blocks == P.ForwardRefBlockAddresses.end()) + return false; + + for (const auto &I : Blocks->second) { + const ValID &BBID = I.first; + GlobalValue *GV = I.second; + + assert((BBID.Kind == ValID::t_LocalID || BBID.Kind == ValID::t_LocalName) && + "Expected local id or name"); + BasicBlock *BB; + if (BBID.Kind == ValID::t_LocalName) + BB = GetBB(BBID.StrVal, BBID.Loc); + else + BB = GetBB(BBID.UIntVal, BBID.Loc); + if (!BB) + return P.Error(BBID.Loc, "referenced value is not a basic block"); + + GV->replaceAllUsesWith(BlockAddress::get(&F, BB)); + GV->eraseFromParent(); + } + + P.ForwardRefBlockAddresses.erase(Blocks); + return false; +} /// ParseFunctionBody /// ::= '{' BasicBlock+ '}' @@ -3358,6 +3383,12 @@ bool LLParser::ParseFunctionBody(Function &Fn) { PerFunctionState PFS(*this, Fn, FunctionNumber); + // Resolve block addresses and allow basic blocks to be forward-declared + // within this function. + if (PFS.resolveForwardRefBlockAddresses()) + return true; + SaveAndRestore ScopeExit(BlockAddressPFS, &PFS); + // We need at least one basic block. if (Lex.getKind() == lltok::rbrace) return TokError("function body requires at least one basic block"); diff --git a/lib/AsmParser/LLParser.h b/lib/AsmParser/LLParser.h index a3eb8d3ed38..556b385c3c4 100644 --- a/lib/AsmParser/LLParser.h +++ b/lib/AsmParser/LLParser.h @@ -128,17 +128,21 @@ namespace llvm { // References to blockaddress. The key is the function ValID, the value is // a list of references to blocks in that function. - std::map > > - ForwardRefBlockAddresses; + std::map> ForwardRefBlockAddresses; + class PerFunctionState; + /// Reference to per-function state to allow basic blocks to be + /// forward-referenced by blockaddress instructions within the same + /// function. + PerFunctionState *BlockAddressPFS; // Attribute builder reference information. std::map > ForwardRefAttrGroups; std::map NumberedAttrBuilders; public: - LLParser(StringRef F, SourceMgr &SM, SMDiagnostic &Err, Module *m) : - Context(m->getContext()), Lex(F, SM, Err, m->getContext()), - M(m) {} + LLParser(StringRef F, SourceMgr &SM, SMDiagnostic &Err, Module *m) + : Context(m->getContext()), Lex(F, SM, Err, m->getContext()), M(m), + BlockAddressPFS(nullptr) {} bool Run(); LLVMContext &getContext() { return Context; } @@ -327,6 +331,8 @@ namespace llvm { /// unnamed. If there is an error, this returns null otherwise it returns /// the block being defined. BasicBlock *DefineBB(const std::string &Name, LocTy Loc); + + bool resolveForwardRefBlockAddresses(); }; bool ConvertValIDToValue(Type *Ty, ValID &ID, Value *&V, @@ -372,7 +378,7 @@ namespace llvm { bool ParseValID(ValID &ID, PerFunctionState *PFS = nullptr); bool ParseGlobalValue(Type *Ty, Constant *&V); bool ParseGlobalTypeAndValue(Constant *&V); - bool ParseGlobalValueVector(SmallVectorImpl &Elts); + bool ParseGlobalValueVector(SmallVectorImpl &Elts); bool parseOptionalComdat(Comdat *&C); bool ParseMetadataListValue(ValID &ID, PerFunctionState *PFS); bool ParseMetadataValue(ValID &ID, PerFunctionState *PFS); @@ -432,10 +438,6 @@ namespace llvm { int ParseGetElementPtr(Instruction *&I, PerFunctionState &PFS); int ParseExtractValue(Instruction *&I, PerFunctionState &PFS); int ParseInsertValue(Instruction *&I, PerFunctionState &PFS); - - bool ResolveForwardRefBlockAddresses(Function *TheFn, - std::vector > &Refs, - PerFunctionState *PFS); }; } // End llvm namespace