From c32c110b909ebea339fccf330774bebaef3ed16d Mon Sep 17 00:00:00 2001 From: Eli Bendersky Date: Mon, 10 Mar 2014 20:05:42 +0000 Subject: [PATCH] Make sure NVPTX doesn't emit symbol names that aren't valid in PTX. NVPTX, like the other backends, relies on generic symbol name sanitizing done by MCSymbol. However, the ptxas assembler is more stringent and disallows some additional characters in symbol names. See PR19099 for more details. git-svn-id: https://llvm.org/svn/llvm-project/llvm/trunk@203483 91177308-0d34-0410-b5e6-96231b3b80d8 --- lib/Target/NVPTX/NVPTXAsmPrinter.cpp | 59 ++++++++++++++++++++-------- lib/Target/NVPTX/NVPTXAsmPrinter.h | 5 +++ 2 files changed, 48 insertions(+), 16 deletions(-) diff --git a/lib/Target/NVPTX/NVPTXAsmPrinter.cpp b/lib/Target/NVPTX/NVPTXAsmPrinter.cpp index 25108bc39aa..0cbdcc49aa9 100644 --- a/lib/Target/NVPTX/NVPTXAsmPrinter.cpp +++ b/lib/Target/NVPTX/NVPTXAsmPrinter.cpp @@ -684,7 +684,7 @@ void NVPTXAsmPrinter::emitDeclaration(const Function *F, raw_ostream &O) { else O << ".func "; printReturnValStr(F, O); - O << *getSymbol(F) << "\n"; + O << getSymbolName(F) << "\n"; emitFunctionParamList(F, O); O << ";\n"; } @@ -1209,7 +1209,7 @@ void NVPTXAsmPrinter::printModuleLevelGV(const GlobalVariable *GVar, else O << getPTXFundamentalTypeStr(ETy, false); O << " "; - O << *getSymbol(GVar); + O << getSymbolName(GVar); // Ptx allows variable initilization only for constant and global state // spaces. @@ -1245,15 +1245,15 @@ void NVPTXAsmPrinter::printModuleLevelGV(const GlobalVariable *GVar, bufferAggregateConstant(Initializer, &aggBuffer); if (aggBuffer.numSymbols) { if (nvptxSubtarget.is64Bit()) { - O << " .u64 " << *getSymbol(GVar) << "["; + O << " .u64 " << getSymbolName(GVar) << "["; O << ElementSize / 8; } else { - O << " .u32 " << *getSymbol(GVar) << "["; + O << " .u32 " << getSymbolName(GVar) << "["; O << ElementSize / 4; } O << "]"; } else { - O << " .b8 " << *getSymbol(GVar) << "["; + O << " .b8 " << getSymbolName(GVar) << "["; O << ElementSize; O << "]"; } @@ -1261,7 +1261,7 @@ void NVPTXAsmPrinter::printModuleLevelGV(const GlobalVariable *GVar, aggBuffer.print(); O << "}"; } else { - O << " .b8 " << *getSymbol(GVar); + O << " .b8 " << getSymbolName(GVar); if (ElementSize) { O << "["; O << ElementSize; @@ -1269,7 +1269,7 @@ void NVPTXAsmPrinter::printModuleLevelGV(const GlobalVariable *GVar, } } } else { - O << " .b8 " << *getSymbol(GVar); + O << " .b8 " << getSymbolName(GVar); if (ElementSize) { O << "["; O << ElementSize; @@ -1376,7 +1376,7 @@ void NVPTXAsmPrinter::emitPTXGlobalVariable(const GlobalVariable *GVar, O << " ."; O << getPTXFundamentalTypeStr(ETy); O << " "; - O << *getSymbol(GVar); + O << getSymbolName(GVar); return; } @@ -1391,7 +1391,7 @@ void NVPTXAsmPrinter::emitPTXGlobalVariable(const GlobalVariable *GVar, case Type::ArrayTyID: case Type::VectorTyID: ElementSize = TD->getTypeStoreSize(ETy); - O << " .b8 " << *getSymbol(GVar) << "["; + O << " .b8 " << getSymbolName(GVar) << "["; if (ElementSize) { O << itostr(ElementSize); } @@ -1446,7 +1446,7 @@ void NVPTXAsmPrinter::printParamName(Function::const_arg_iterator I, int paramIndex, raw_ostream &O) { if ((nvptxSubtarget.getDrvInterface() == NVPTX::NVCL) || (nvptxSubtarget.getDrvInterface() == NVPTX::CUDA)) - O << *getSymbol(I->getParent()) << "_param_" << paramIndex; + O << getSymbolName(I->getParent()) << "_param_" << paramIndex; else { std::string argName = I->getName(); const char *p = argName.c_str(); @@ -1505,13 +1505,13 @@ void NVPTXAsmPrinter::emitFunctionParamList(const Function *F, raw_ostream &O) { if (llvm::isImage(*I)) { std::string sname = I->getName(); if (llvm::isImageWriteOnly(*I)) - O << "\t.param .surfref " << *getSymbol(F) << "_param_" + O << "\t.param .surfref " << getSymbolName(F) << "_param_" << paramIndex; else // Default image is read_only - O << "\t.param .texref " << *getSymbol(F) << "_param_" + O << "\t.param .texref " << getSymbolName(F) << "_param_" << paramIndex; } else // Should be llvm::isSampler(*I) - O << "\t.param .samplerref " << *getSymbol(F) << "_param_" + O << "\t.param .samplerref " << getSymbolName(F) << "_param_" << paramIndex; continue; } @@ -1758,13 +1758,13 @@ void NVPTXAsmPrinter::printScalarConstant(const Constant *CPV, raw_ostream &O) { return; } if (const GlobalValue *GVar = dyn_cast(CPV)) { - O << *getSymbol(GVar); + O << getSymbolName(GVar); return; } if (const ConstantExpr *Cexpr = dyn_cast(CPV)) { const Value *v = Cexpr->stripPointerCasts(); if (const GlobalValue *GVar = dyn_cast(v)) { - O << *getSymbol(GVar); + O << getSymbolName(GVar); return; } else { O << *LowerConstant(CPV, *this); @@ -2078,7 +2078,7 @@ void NVPTXAsmPrinter::printOperand(const MachineInstr *MI, int opNum, break; case MachineOperand::MO_GlobalAddress: - O << *getSymbol(MO.getGlobal()); + O << getSymbolName(MO.getGlobal()); break; case MachineOperand::MO_MachineBasicBlock: @@ -2139,6 +2139,33 @@ LineReader *NVPTXAsmPrinter::getReader(std::string filename) { return reader; } +std::string NVPTXAsmPrinter::getSymbolName(const GlobalValue *GV) const { + // Obtain the original symbol name. + MCSymbol *Sym = getSymbol(GV); + std::string OriginalName; + raw_string_ostream OriginalNameStream(OriginalName); + Sym->print(OriginalNameStream); + OriginalNameStream.flush(); + + // MCSymbol already does symbol-name sanitizing, so names it produces are + // valid for object files. The only two characters valida in that context + // and indigestible by the PTX assembler are '.' and '@'. + std::string CleanName; + raw_string_ostream CleanNameStream(CleanName); + for (unsigned I = 0, E = OriginalName.size(); I != E; ++I) { + char C = OriginalName[I]; + if (C == '.') { + CleanNameStream << "_$_"; + } else if (C == '@') { + CleanNameStream << "_%_"; + } else { + CleanNameStream << C; + } + } + + return CleanNameStream.str(); +} + std::string LineReader::readLine(unsigned lineNum) { if (lineNum < theCurLine) { theCurLine = 0; diff --git a/lib/Target/NVPTX/NVPTXAsmPrinter.h b/lib/Target/NVPTX/NVPTXAsmPrinter.h index 71624200d0e..abce85c39d7 100644 --- a/lib/Target/NVPTX/NVPTXAsmPrinter.h +++ b/lib/Target/NVPTX/NVPTXAsmPrinter.h @@ -276,6 +276,11 @@ private: LineReader *reader; LineReader *getReader(std::string); + + // Get the symbol name of the given global symbol. + // + // Cleans up the name so it's a valid in PTX assembly. + std::string getSymbolName(const GlobalValue *GV) const; public: NVPTXAsmPrinter(TargetMachine &TM, MCStreamer &Streamer) : AsmPrinter(TM, Streamer),