From 8e5d01cd6efa99a26c8584711a7e8abbcf14c333 Mon Sep 17 00:00:00 2001 From: Che-Liang Chiou Date: Thu, 10 Feb 2011 12:01:24 +0000 Subject: [PATCH] ptx: add passing parameter to kernel functions git-svn-id: https://llvm.org/svn/llvm-project/llvm/trunk@125279 91177308-0d34-0410-b5e6-96231b3b80d8 --- lib/Target/PTX/PTXAsmPrinter.cpp | 12 +++- lib/Target/PTX/PTXISelDAGToDAG.cpp | 19 +++++- lib/Target/PTX/PTXISelLowering.cpp | 83 +++++++++---------------- lib/Target/PTX/PTXISelLowering.h | 1 + lib/Target/PTX/PTXInstrInfo.h | 24 +++++++ lib/Target/PTX/PTXInstrInfo.td | 10 ++- lib/Target/PTX/PTXMFInfoExtract.cpp | 4 +- lib/Target/PTX/PTXMachineFunctionInfo.h | 4 +- test/CodeGen/PTX/exit.ll | 9 ++- 9 files changed, 104 insertions(+), 62 deletions(-) diff --git a/lib/Target/PTX/PTXAsmPrinter.cpp b/lib/Target/PTX/PTXAsmPrinter.cpp index 872287eeea8..a6059974ab3 100644 --- a/lib/Target/PTX/PTXAsmPrinter.cpp +++ b/lib/Target/PTX/PTXAsmPrinter.cpp @@ -38,12 +38,11 @@ using namespace llvm; static cl::opt -OptPTXVersion("ptx-version", cl::desc("Set PTX version"), - cl::init("1.4")); +OptPTXVersion("ptx-version", cl::desc("Set PTX version"), cl::init("1.4")); static cl::opt OptPTXTarget("ptx-target", cl::desc("Set GPU target (comma-separated list)"), - cl::init("sm_10")); + cl::init("sm_10")); namespace { class PTXAsmPrinter : public AsmPrinter { @@ -67,6 +66,8 @@ public: void printOperand(const MachineInstr *MI, int opNum, raw_ostream &OS); void printMemOperand(const MachineInstr *MI, int opNum, raw_ostream &OS, const char *Modifier = 0); + void printParamOperand(const MachineInstr *MI, int opNum, raw_ostream &OS, + const char *Modifier = 0); // autogen'd. void printInstruction(const MachineInstr *MI, raw_ostream &OS); @@ -231,6 +232,11 @@ void PTXAsmPrinter::printMemOperand(const MachineInstr *MI, int opNum, printOperand(MI, opNum+1, OS); } +void PTXAsmPrinter::printParamOperand(const MachineInstr *MI, int opNum, + raw_ostream &OS, const char *Modifier) { + OS << PARAM_PREFIX << (int) MI->getOperand(opNum).getImm() + 1; +} + void PTXAsmPrinter::EmitVariableDeclaration(const GlobalVariable *gv) { // Check to see if this is a special global used by LLVM, if so, emit it. if (EmitSpecialLLVMGlobal(gv)) diff --git a/lib/Target/PTX/PTXISelDAGToDAG.cpp b/lib/Target/PTX/PTXISelDAGToDAG.cpp index 294b62f08f3..efb0e8b1af7 100644 --- a/lib/Target/PTX/PTXISelDAGToDAG.cpp +++ b/lib/Target/PTX/PTXISelDAGToDAG.cpp @@ -40,6 +40,8 @@ class PTXDAGToDAGISel : public SelectionDAGISel { #include "PTXGenDAGISel.inc" private: + SDNode *SelectREAD_PARAM(SDNode *Node); + bool isImm(const SDValue &operand); bool SelectImm(const SDValue &operand, SDValue &imm); }; // class PTXDAGToDAGISel @@ -57,8 +59,21 @@ PTXDAGToDAGISel::PTXDAGToDAGISel(PTXTargetMachine &TM, : SelectionDAGISel(TM, OptLevel) {} SDNode *PTXDAGToDAGISel::Select(SDNode *Node) { - // SelectCode() is auto'gened - return SelectCode(Node); + if (Node->getOpcode() == PTXISD::READ_PARAM) + return SelectREAD_PARAM(Node); + else + return SelectCode(Node); +} + +SDNode *PTXDAGToDAGISel::SelectREAD_PARAM(SDNode *Node) { + SDValue index = Node->getOperand(1); + DebugLoc dl = Node->getDebugLoc(); + + if (index.getOpcode() != ISD::TargetConstant) + llvm_unreachable("READ_PARAM: index is not ISD::TargetConstant"); + + return PTXInstrInfo:: + GetPTXMachineNode(CurDAG, PTX::LDpi, dl, MVT::i32, index); } // Match memory operand of the form [reg+reg] diff --git a/lib/Target/PTX/PTXISelLowering.cpp b/lib/Target/PTX/PTXISelLowering.cpp index f05bd47b7fe..e6d44907ed3 100644 --- a/lib/Target/PTX/PTXISelLowering.cpp +++ b/lib/Target/PTX/PTXISelLowering.cpp @@ -47,9 +47,14 @@ SDValue PTXTargetLowering::LowerOperation(SDValue Op, SelectionDAG &DAG) const { const char *PTXTargetLowering::getTargetNodeName(unsigned Opcode) const { switch (Opcode) { - default: llvm_unreachable("Unknown opcode"); - case PTXISD::EXIT: return "PTXISD::EXIT"; - case PTXISD::RET: return "PTXISD::RET"; + default: + llvm_unreachable("Unknown opcode"); + case PTXISD::READ_PARAM: + return "PTXISD::READ_PARAM"; + case PTXISD::EXIT: + return "PTXISD::EXIT"; + case PTXISD::RET: + return "PTXISD::RET"; } } @@ -86,42 +91,6 @@ struct argmap_entry { }; } // end anonymous namespace -static SDValue lower_kernel_argument(int i, - SDValue Chain, - DebugLoc dl, - MVT::SimpleValueType VT, - argmap_entry *entry, - SelectionDAG &DAG, - unsigned *argreg) { - // TODO - llvm_unreachable("Not implemented yet"); -} - -static SDValue lower_device_argument(int i, - SDValue Chain, - DebugLoc dl, - MVT::SimpleValueType VT, - argmap_entry *entry, - SelectionDAG &DAG, - unsigned *argreg) { - MachineRegisterInfo &RegInfo = DAG.getMachineFunction().getRegInfo(); - - unsigned preg = *++(entry->loc); // allocate start from register 1 - unsigned vreg = RegInfo.createVirtualRegister(entry->RC); - RegInfo.addLiveIn(preg, vreg); - - *argreg = preg; - return DAG.getCopyFromReg(Chain, dl, vreg, VT); -} - -typedef SDValue (*lower_argument_func)(int i, - SDValue Chain, - DebugLoc dl, - MVT::SimpleValueType VT, - argmap_entry *entry, - SelectionDAG &DAG, - unsigned *argreg); - SDValue PTXTargetLowering:: LowerFormalArguments(SDValue Chain, CallingConv::ID CallConv, @@ -135,22 +104,22 @@ SDValue PTXTargetLowering:: MachineFunction &MF = DAG.getMachineFunction(); PTXMachineFunctionInfo *MFI = MF.getInfo(); - lower_argument_func lower_argument; - switch (CallConv) { default: llvm_unreachable("Unsupported calling convention"); break; case CallingConv::PTX_Kernel: - MFI->setKernel(); - lower_argument = lower_kernel_argument; + MFI->setKernel(true); break; case CallingConv::PTX_Device: MFI->setKernel(false); - lower_argument = lower_device_argument; break; } + // Make sure we don't add argument registers twice + if (MFI->isDoneAddArg()) + llvm_unreachable("cannot add argument registers twice"); + // Reset argmap before allocation for (struct argmap_entry *i = argmap, *e = argmap + array_lengthof(argmap); i != e; ++ i) @@ -164,17 +133,27 @@ SDValue PTXTargetLowering:: if (entry == argmap + array_lengthof(argmap)) llvm_unreachable("Type of argument is not supported"); - unsigned reg; - SDValue arg = lower_argument(i, Chain, dl, VT, entry, DAG, ®); - InVals.push_back(arg); + if (MFI->isKernel() && entry->RC == PTX::PredsRegisterClass) + llvm_unreachable("cannot pass preds to kernel"); - if (!MFI->isDoneAddArg()) - MFI->addArgReg(reg); + MachineRegisterInfo &RegInfo = DAG.getMachineFunction().getRegInfo(); + + unsigned preg = *++(entry->loc); // allocate start from register 1 + unsigned vreg = RegInfo.createVirtualRegister(entry->RC); + RegInfo.addLiveIn(preg, vreg); + + MFI->addArgReg(preg); + + SDValue inval; + if (MFI->isKernel()) + inval = DAG.getNode(PTXISD::READ_PARAM, dl, VT, Chain, + DAG.getTargetConstant(i, MVT::i32)); + else + inval = DAG.getCopyFromReg(Chain, dl, vreg, VT); + InVals.push_back(inval); } - // Make sure we don't add argument registers twice - if (!MFI->isDoneAddArg()) - MFI->doneAddArg(); + MFI->doneAddArg(); return Chain; } diff --git a/lib/Target/PTX/PTXISelLowering.h b/lib/Target/PTX/PTXISelLowering.h index 14f2fc01467..b03a9f66630 100644 --- a/lib/Target/PTX/PTXISelLowering.h +++ b/lib/Target/PTX/PTXISelLowering.h @@ -24,6 +24,7 @@ class PTXTargetMachine; namespace PTXISD { enum NodeType { FIRST_NUMBER = ISD::BUILTIN_OP_END, + READ_PARAM, EXIT, RET }; diff --git a/lib/Target/PTX/PTXInstrInfo.h b/lib/Target/PTX/PTXInstrInfo.h index 9d9ffe1d23a..e7f00f09c2f 100644 --- a/lib/Target/PTX/PTXInstrInfo.h +++ b/lib/Target/PTX/PTXInstrInfo.h @@ -15,6 +15,8 @@ #define PTX_INSTR_INFO_H #include "PTXRegisterInfo.h" +#include "llvm/CodeGen/SelectionDAG.h" +#include "llvm/CodeGen/SelectionDAGNodes.h" #include "llvm/Target/TargetInstrInfo.h" namespace llvm { @@ -45,6 +47,28 @@ class PTXInstrInfo : public TargetInstrInfoImpl { virtual bool isMoveInstr(const MachineInstr& MI, unsigned &SrcReg, unsigned &DstReg, unsigned &SrcSubIdx, unsigned &DstSubIdx) const; + + // static helper routines + + static MachineSDNode *GetPTXMachineNode(SelectionDAG *DAG, unsigned Opcode, + DebugLoc dl, EVT VT, + SDValue Op1) { + SDValue pred_reg = DAG->getRegister(0, MVT::i1); + SDValue pred_imm = DAG->getTargetConstant(0, MVT::i32); + SDValue ops[] = { Op1, pred_reg, pred_imm }; + return DAG->getMachineNode(Opcode, dl, VT, ops, array_lengthof(ops)); + } + + static MachineSDNode *GetPTXMachineNode(SelectionDAG *DAG, unsigned Opcode, + DebugLoc dl, EVT VT, + SDValue Op1, + SDValue Op2) { + SDValue pred_reg = DAG->getRegister(0, MVT::i1); + SDValue pred_imm = DAG->getTargetConstant(0, MVT::i32); + SDValue ops[] = { Op1, Op2, pred_reg, pred_imm }; + return DAG->getMachineNode(Opcode, dl, VT, ops, array_lengthof(ops)); + } + }; // class PTXInstrInfo } // namespace llvm diff --git a/lib/Target/PTX/PTXInstrInfo.td b/lib/Target/PTX/PTXInstrInfo.td index 8e4b7203084..9a747788f6a 100644 --- a/lib/Target/PTX/PTXInstrInfo.td +++ b/lib/Target/PTX/PTXInstrInfo.td @@ -120,6 +120,10 @@ def MEMii : Operand { let PrintMethod = "printMemOperand"; let MIOperandInfo = (ops i32imm, i32imm); } +def MEMpi : Operand { + let PrintMethod = "printParamOperand"; + let MIOperandInfo = (ops i32imm); +} //===----------------------------------------------------------------------===// // PTX Specific Node Definitions @@ -236,9 +240,13 @@ defm LDl : PTX_LD<"ld.local", RRegs32, load_local>; defm LDp : PTX_LD<"ld.param", RRegs32, load_parameter>; defm LDs : PTX_LD<"ld.shared", RRegs32, load_shared>; +def LDpi : InstPTX<(outs RRegs32:$d), (ins MEMpi:$a), + "ld.param.%type\t$d, [$a]", []>; + defm STg : PTX_ST<"st.global", RRegs32, store_global>; defm STl : PTX_ST<"st.local", RRegs32, store_local>; -defm STp : PTX_ST<"st.param", RRegs32, store_parameter>; +// Store to parameter state space requires PTX 2.0 or higher? +// defm STp : PTX_ST<"st.param", RRegs32, store_parameter>; defm STs : PTX_ST<"st.shared", RRegs32, store_shared>; ///===- Control Flow Instructions -----------------------------------------===// diff --git a/lib/Target/PTX/PTXMFInfoExtract.cpp b/lib/Target/PTX/PTXMFInfoExtract.cpp index 68b641b89a0..b37c740006f 100644 --- a/lib/Target/PTX/PTXMFInfoExtract.cpp +++ b/lib/Target/PTX/PTXMFInfoExtract.cpp @@ -67,7 +67,9 @@ bool PTXMFInfoExtract::runOnMachineFunction(MachineFunction &MF) { // FIXME: This is a slow linear scanning for (unsigned reg = PTX::NoRegister + 1; reg < PTX::NUM_TARGET_REGS; ++reg) - if (MRI.isPhysRegUsed(reg) && reg != retreg && !MFI->isArgReg(reg)) + if (MRI.isPhysRegUsed(reg) && + reg != retreg && + (MFI->isKernel() || !MFI->isArgReg(reg))) MFI->addLocalVarReg(reg); // Notify MachineFunctionInfo that I've done adding local var reg diff --git a/lib/Target/PTX/PTXMachineFunctionInfo.h b/lib/Target/PTX/PTXMachineFunctionInfo.h index 6fbad5f8949..56d044b5fc0 100644 --- a/lib/Target/PTX/PTXMachineFunctionInfo.h +++ b/lib/Target/PTX/PTXMachineFunctionInfo.h @@ -31,8 +31,8 @@ private: public: PTXMachineFunctionInfo(MachineFunction &MF) : is_kernel(false), reg_ret(PTX::NoRegister), _isDoneAddArg(false) { - reg_arg.reserve(32); - reg_local_var.reserve(64); + reg_arg.reserve(8); + reg_local_var.reserve(32); } void setKernel(bool _is_kernel=true) { is_kernel = _is_kernel; } diff --git a/test/CodeGen/PTX/exit.ll b/test/CodeGen/PTX/exit.ll index 396898b623c..4071babb80c 100644 --- a/test/CodeGen/PTX/exit.ll +++ b/test/CodeGen/PTX/exit.ll @@ -3,5 +3,12 @@ define ptx_kernel void @t1() { ; CHECK: exit; ; CHECK-NOT: ret; - ret void + ret void +} + +define ptx_kernel void @t2(i32* %p, i32 %x) { + store i32 %x, i32* %p +; CHECK: exit; +; CHECK-NOT: ret; + ret void }