From 96e6458903ab0799542365cac98653c207984162 Mon Sep 17 00:00:00 2001 From: Dan Bailey Date: Fri, 11 Nov 2011 14:45:12 +0000 Subject: [PATCH] allow non-device function calls in PTX when natively handling device-side printf git-svn-id: https://llvm.org/svn/llvm-project/llvm/trunk@144388 91177308-0d34-0410-b5e6-96231b3b80d8 --- lib/Target/PTX/InstPrinter/PTXInstPrinter.cpp | 18 +++- lib/Target/PTX/PTXAsmPrinter.cpp | 30 ++++++ lib/Target/PTX/PTXAsmPrinter.h | 2 +- lib/Target/PTX/PTXISelLowering.cpp | 102 ++++++++++++++---- test/CodeGen/PTX/printf.ll | 25 +++++ 5 files changed, 154 insertions(+), 23 deletions(-) create mode 100644 test/CodeGen/PTX/printf.ll diff --git a/lib/Target/PTX/InstPrinter/PTXInstPrinter.cpp b/lib/Target/PTX/InstPrinter/PTXInstPrinter.cpp index aabb404dad6..2f6c92d11cc 100644 --- a/lib/Target/PTX/InstPrinter/PTXInstPrinter.cpp +++ b/lib/Target/PTX/InstPrinter/PTXInstPrinter.cpp @@ -96,9 +96,23 @@ void PTXInstPrinter::printCall(const MCInst *MI, raw_ostream &O) { O << "), "; } - O << *(MI->getOperand(Index++).getExpr()) << ", ("; - + const MCExpr* Expr = MI->getOperand(Index++).getExpr(); unsigned NumArgs = MI->getOperand(Index++).getImm(); + + // if the function call is to printf or puts, change to vprintf + if (const MCSymbolRefExpr *SymRefExpr = dyn_cast(Expr)) { + const MCSymbol &Sym = SymRefExpr->getSymbol(); + if (Sym.getName() == "printf" || Sym.getName() == "puts") { + O << "vprintf"; + } else { + O << Sym.getName(); + } + } else { + O << *Expr; + } + + O << ", ("; + if (NumArgs > 0) { printOperand(MI, Index++, O); for (unsigned i = 1; i < NumArgs; ++i) { diff --git a/lib/Target/PTX/PTXAsmPrinter.cpp b/lib/Target/PTX/PTXAsmPrinter.cpp index 45a6afc8587..bdf238b1b04 100644 --- a/lib/Target/PTX/PTXAsmPrinter.cpp +++ b/lib/Target/PTX/PTXAsmPrinter.cpp @@ -165,6 +165,11 @@ void PTXAsmPrinter::EmitStartOfAsmFile(Module &M) OutStreamer.AddBlankLine(); + // declare external functions + for (Module::const_iterator i = M.begin(), e = M.end(); + i != e; ++i) + EmitFunctionDeclaration(i); + // declare global variables for (Module::const_global_iterator i = M.global_begin(), e = M.global_end(); i != e; ++i) @@ -454,6 +459,31 @@ void PTXAsmPrinter::EmitFunctionEntryLabel() { OutStreamer.EmitRawText(os.str()); } +void PTXAsmPrinter::EmitFunctionDeclaration(const Function* func) +{ + const PTXSubtarget& ST = TM.getSubtarget(); + + std::string decl = ""; + + // hard-coded emission of extern vprintf function + + if (func->getName() == "printf" || func->getName() == "puts") { + decl += ".extern .func (.param .b32 __param_1) vprintf (.param .b"; + if (ST.is64Bit()) + decl += "64"; + else + decl += "32"; + decl += " __param_2, .param .b"; + if (ST.is64Bit()) + decl += "64"; + else + decl += "32"; + decl += " __param_3)\n"; + } + + OutStreamer.EmitRawText(Twine(decl)); +} + unsigned PTXAsmPrinter::GetOrCreateSourceID(StringRef FileName, StringRef DirName) { // If FE did not provide a file name, then assume stdin. diff --git a/lib/Target/PTX/PTXAsmPrinter.h b/lib/Target/PTX/PTXAsmPrinter.h index 538c0802a27..d5ea4dbc59c 100644 --- a/lib/Target/PTX/PTXAsmPrinter.h +++ b/lib/Target/PTX/PTXAsmPrinter.h @@ -47,7 +47,7 @@ public: private: void EmitVariableDeclaration(const GlobalVariable *gv); - void EmitFunctionDeclaration(); + void EmitFunctionDeclaration(const Function* func); StringMap SourceIdMap; }; // class PTXAsmPrinter diff --git a/lib/Target/PTX/PTXISelLowering.cpp b/lib/Target/PTX/PTXISelLowering.cpp index 3307d91a618..7f55871f63b 100644 --- a/lib/Target/PTX/PTXISelLowering.cpp +++ b/lib/Target/PTX/PTXISelLowering.cpp @@ -20,6 +20,7 @@ #include "llvm/Support/ErrorHandling.h" #include "llvm/CodeGen/CallingConvLower.h" #include "llvm/CodeGen/MachineFunction.h" +#include "llvm/CodeGen/MachineFrameInfo.h" #include "llvm/CodeGen/MachineRegisterInfo.h" #include "llvm/CodeGen/SelectionDAG.h" #include "llvm/CodeGen/TargetLoweringObjectFileImpl.h" @@ -352,40 +353,101 @@ PTXTargetLowering::LowerCall(SDValue Chain, SDValue Callee, SmallVectorImpl &InVals) const { MachineFunction& MF = DAG.getMachineFunction(); - PTXMachineFunctionInfo *MFI = MF.getInfo(); - PTXParamManager &PM = MFI->getParamManager(); - + PTXMachineFunctionInfo *PTXMFI = MF.getInfo(); + PTXParamManager &PM = PTXMFI->getParamManager(); + MachineFrameInfo *MFI = MF.getFrameInfo(); + assert(getTargetMachine().getSubtarget().callsAreHandled() && "Calls are not handled for the target device"); + // Identify the callee function + const GlobalValue *GV = cast(Callee)->getGlobal(); + const Function *function = cast(GV); + + // allow non-device calls only for printf + bool isPrintf = function->getName() == "printf" || function->getName() == "puts"; + + assert((isPrintf || function->getCallingConv() == CallingConv::PTX_Device) && + "PTX function calls must be to PTX device functions"); + + unsigned outSize = isPrintf ? 2 : Outs.size(); + std::vector Ops; // The layout of the ops will be [Chain, #Ins, Ins, Callee, #Outs, Outs] - Ops.resize(Outs.size() + Ins.size() + 4); + Ops.resize(outSize + Ins.size() + 4); Ops[0] = Chain; // Identify the callee function - const GlobalValue *GV = cast(Callee)->getGlobal(); - assert(cast(GV)->getCallingConv() == CallingConv::PTX_Device && - "PTX function calls must be to PTX device functions"); Callee = DAG.getTargetGlobalAddress(GV, dl, getPointerTy()); Ops[Ins.size()+2] = Callee; - // Generate STORE_PARAM nodes for each function argument. In PTX, function - // arguments are explicitly stored into .param variables and passed as - // arguments. There is no register/stack-based calling convention in PTX. - Ops[Ins.size()+3] = DAG.getTargetConstant(OutVals.size(), MVT::i32); - for (unsigned i = 0; i != OutVals.size(); ++i) { - unsigned Size = OutVals[i].getValueType().getSizeInBits(); - unsigned Param = PM.addLocalParam(Size); - const std::string &ParamName = PM.getParamName(Param); - SDValue ParamValue = DAG.getTargetExternalSymbol(ParamName.c_str(), - MVT::Other); + // #Outs + Ops[Ins.size()+3] = DAG.getTargetConstant(outSize, MVT::i32); + + if (isPrintf) { + // first argument is the address of the global string variable in memory + unsigned Param0 = PM.addLocalParam(getPointerTy().getSizeInBits()); + SDValue ParamValue0 = DAG.getTargetExternalSymbol(PM.getParamName(Param0).c_str(), + MVT::Other); Chain = DAG.getNode(PTXISD::STORE_PARAM, dl, MVT::Other, Chain, - ParamValue, OutVals[i]); - Ops[i+Ins.size()+4] = ParamValue; - } + ParamValue0, OutVals[0]); + Ops[Ins.size()+4] = ParamValue0; + + // alignment is the maximum size of all the arguments + unsigned alignment = 0; + for (unsigned i = 1; i < OutVals.size(); ++i) { + alignment = std::max(alignment, + OutVals[i].getValueType().getSizeInBits()); + } + // size is the alignment multiplied by the number of arguments + unsigned size = alignment * (OutVals.size() - 1); + + // second argument is the address of the stack object (unless no arguments) + unsigned Param1 = PM.addLocalParam(getPointerTy().getSizeInBits()); + SDValue ParamValue1 = DAG.getTargetExternalSymbol(PM.getParamName(Param1).c_str(), + MVT::Other); + Ops[Ins.size()+5] = ParamValue1; + + if (size > 0) + { + // create a local stack object to store the arguments + unsigned StackObject = MFI->CreateStackObject(size / 8, alignment / 8, false); + SDValue FrameIndex = DAG.getFrameIndex(StackObject, getPointerTy()); + + // store each of the arguments to the stack in turn + for (unsigned int i = 1; i != OutVals.size(); i++) { + SDValue FrameAddr = DAG.getNode(ISD::ADD, dl, getPointerTy(), FrameIndex, DAG.getTargetConstant((i - 1) * 8, getPointerTy())); + Chain = DAG.getStore(Chain, dl, OutVals[i], FrameAddr, + MachinePointerInfo(), + false, false, 0); + } + + // copy the address of the local frame index to get the address in non-local space + SDValue genericAddr = DAG.getNode(PTXISD::COPY_ADDRESS, dl, getPointerTy(), FrameIndex); + + // store this address in the second argument + Chain = DAG.getNode(PTXISD::STORE_PARAM, dl, MVT::Other, Chain, ParamValue1, genericAddr); + } + } + else + { + // Generate STORE_PARAM nodes for each function argument. In PTX, function + // arguments are explicitly stored into .param variables and passed as + // arguments. There is no register/stack-based calling convention in PTX. + for (unsigned i = 0; i != OutVals.size(); ++i) { + unsigned Size = OutVals[i].getValueType().getSizeInBits(); + unsigned Param = PM.addLocalParam(Size); + const std::string &ParamName = PM.getParamName(Param); + SDValue ParamValue = DAG.getTargetExternalSymbol(ParamName.c_str(), + MVT::Other); + Chain = DAG.getNode(PTXISD::STORE_PARAM, dl, MVT::Other, Chain, + ParamValue, OutVals[i]); + Ops[i+Ins.size()+4] = ParamValue; + } + } + std::vector InParams; // Generate list of .param variables to hold the return value(s). diff --git a/test/CodeGen/PTX/printf.ll b/test/CodeGen/PTX/printf.ll new file mode 100644 index 00000000000..f901b2055f0 --- /dev/null +++ b/test/CodeGen/PTX/printf.ll @@ -0,0 +1,25 @@ +; RUN: llc < %s -march=ptx64 -mattr=+ptx20,+sm20 | FileCheck %s + +declare i32 @printf(i8*, ...) + +@str = private unnamed_addr constant [6 x i8] c"test\0A\00" + +define ptx_device void @t1_printf() { +; CHECK: mov.u64 %rd{{[0-9]+}}, $L__str; +; CHECK: call.uni (__localparam_{{[0-9]+}}), vprintf, (__localparam_{{[0-9]+}}, __localparam_{{[0-9]+}}); +; CHECK: ret; + %1 = call i32 (i8*, ...)* @printf(i8* getelementptr inbounds ([6 x i8]* @str, i64 0, i64 0)) + ret void +} + +@str2 = private unnamed_addr constant [11 x i8] c"test = %f\0A\00" + +define ptx_device void @t2_printf() { +; CHECK: .local .align 8 .b8 __local{{[0-9]+}}[{{[0-9]+}}]; +; CHECK: mov.u64 %rd{{[0-9]+}}, $L__str2; +; CHECK: cvta.local.u64 %rd{{[0-9]+}}, __local{{[0-9+]}}; +; CHECK: call.uni (__localparam_{{[0-9]+}}), vprintf, (__localparam_{{[0-9]+}}, __localparam_{{[0-9]+}}); +; CHECK: ret; + %1 = call i32 (i8*, ...)* @printf(i8* getelementptr inbounds ([11 x i8]* @str2, i64 0, i64 0), double 0x3FF3333340000000) + ret void +}