PTX: Use .param space for parameters in device functions for SM >= 2.0

FIXME: DCE is eliminating the final st.param.x calls, figure out why

git-svn-id: https://llvm.org/svn/llvm-project/llvm/trunk@133732 91177308-0d34-0410-b5e6-96231b3b80d8
This commit is contained in:
Justin Holewinski
2011-06-23 18:10:03 +00:00
parent 28578e3137
commit 67a9184861
6 changed files with 126 additions and 33 deletions

View File

@@ -417,6 +417,7 @@ void PTXAsmPrinter::EmitFunctionDeclaration() {
const PTXMachineFunctionInfo *MFI = MF->getInfo<PTXMachineFunctionInfo>(); const PTXMachineFunctionInfo *MFI = MF->getInfo<PTXMachineFunctionInfo>();
const bool isKernel = MFI->isKernel(); const bool isKernel = MFI->isKernel();
const PTXSubtarget& ST = TM.getSubtarget<PTXSubtarget>();
std::string decl = isKernel ? ".entry" : ".func"; std::string decl = isKernel ? ".entry" : ".func";
@@ -452,7 +453,7 @@ void PTXAsmPrinter::EmitFunctionDeclaration() {
if (i != b) { if (i != b) {
decl += ", "; decl += ", ";
} }
if (isKernel) { if (isKernel || ST.getShaderModel() >= PTXSubtarget::PTX_SM_2_0) {
decl += ".param .b"; decl += ".param .b";
decl += utostr(*i); decl += utostr(*i);
decl += " "; decl += " ";

View File

@@ -15,6 +15,7 @@
#include "PTXTargetMachine.h" #include "PTXTargetMachine.h"
#include "llvm/CodeGen/SelectionDAGISel.h" #include "llvm/CodeGen/SelectionDAGISel.h"
#include "llvm/DerivedTypes.h" #include "llvm/DerivedTypes.h"
#include "llvm/Support/Debug.h"
#include "llvm/Support/raw_ostream.h" #include "llvm/Support/raw_ostream.h"
using namespace llvm; using namespace llvm;
@@ -42,7 +43,8 @@ class PTXDAGToDAGISel : public SelectionDAGISel {
private: private:
SDNode *SelectREAD_PARAM(SDNode *Node); SDNode *SelectREAD_PARAM(SDNode *Node);
//SDNode *SelectSTORE_PARAM(SDNode *Node);
// We need this only because we can't match intruction BRAdp // We need this only because we can't match intruction BRAdp
// pattern (PTXbrcond bb:$d, ...) in PTXInstrInfo.td // pattern (PTXbrcond bb:$d, ...) in PTXInstrInfo.td
SDNode *SelectBRCOND(SDNode *Node); SDNode *SelectBRCOND(SDNode *Node);
@@ -69,6 +71,8 @@ SDNode *PTXDAGToDAGISel::Select(SDNode *Node) {
switch (Node->getOpcode()) { switch (Node->getOpcode()) {
case PTXISD::READ_PARAM: case PTXISD::READ_PARAM:
return SelectREAD_PARAM(Node); return SelectREAD_PARAM(Node);
// case PTXISD::STORE_PARAM:
// return SelectSTORE_PARAM(Node);
case ISD::BRCOND: case ISD::BRCOND:
return SelectBRCOND(Node); return SelectBRCOND(Node);
default: default:
@@ -86,20 +90,15 @@ SDNode *PTXDAGToDAGISel::SelectREAD_PARAM(SDNode *Node) {
if (Node->getValueType(0) == MVT::i16) { if (Node->getValueType(0) == MVT::i16) {
opcode = PTX::LDpiU16; opcode = PTX::LDpiU16;
} } else if (Node->getValueType(0) == MVT::i32) {
else if (Node->getValueType(0) == MVT::i32) {
opcode = PTX::LDpiU32; opcode = PTX::LDpiU32;
} } else if (Node->getValueType(0) == MVT::i64) {
else if (Node->getValueType(0) == MVT::i64) {
opcode = PTX::LDpiU64; opcode = PTX::LDpiU64;
} } else if (Node->getValueType(0) == MVT::f32) {
else if (Node->getValueType(0) == MVT::f32) {
opcode = PTX::LDpiF32; opcode = PTX::LDpiF32;
} } else if (Node->getValueType(0) == MVT::f64) {
else if (Node->getValueType(0) == MVT::f64) {
opcode = PTX::LDpiF64; opcode = PTX::LDpiF64;
} } else {
else {
llvm_unreachable("Unknown parameter type for ld.param"); llvm_unreachable("Unknown parameter type for ld.param");
} }
@@ -107,6 +106,42 @@ SDNode *PTXDAGToDAGISel::SelectREAD_PARAM(SDNode *Node) {
GetPTXMachineNode(CurDAG, opcode, dl, Node->getValueType(0), index); GetPTXMachineNode(CurDAG, opcode, dl, Node->getValueType(0), index);
} }
// SDNode *PTXDAGToDAGISel::SelectSTORE_PARAM(SDNode *Node) {
// SDValue Chain = Node->getOperand(0);
// SDValue index = Node->getOperand(1);
// SDValue value = Node->getOperand(2);
// DebugLoc dl = Node->getDebugLoc();
// unsigned opcode;
// if (index.getOpcode() != ISD::TargetConstant)
// llvm_unreachable("STORE_PARAM: index is not ISD::TargetConstant");
// if (value->getValueType(0) == MVT::i16) {
// opcode = PTX::STpiU16;
// } else if (value->getValueType(0) == MVT::i32) {
// opcode = PTX::STpiU32;
// } else if (value->getValueType(0) == MVT::i64) {
// opcode = PTX::STpiU64;
// } else if (value->getValueType(0) == MVT::f32) {
// opcode = PTX::STpiF32;
// } else if (value->getValueType(0) == MVT::f64) {
// opcode = PTX::STpiF64;
// } else {
// llvm_unreachable("Unknown parameter type for st.param");
// }
// SDVTList VTs = CurDAG->getVTList(MVT::Other, MVT::Glue);
// SDValue PredReg = CurDAG->getRegister(PTX::NoRegister, MVT::i1);
// SDValue PredOp = CurDAG->getTargetConstant(PTX::PRED_NORMAL, MVT::i32);
// SDValue Ops[] = { Chain, index, value, PredReg, PredOp };
// //SDNode *RetNode = PTXInstrInfo::
// // GetPTXMachineNode(CurDAG, opcode, dl, VTs, index, value);
// SDNode *RetNode = CurDAG->getMachineNode(opcode, dl, VTs, Ops, array_lengthof(Ops));
// DEBUG(dbgs() << "SelectSTORE_PARAM: Selected: ");
// RetNode->dumpr(CurDAG);
// return RetNode;
// }
SDNode *PTXDAGToDAGISel::SelectBRCOND(SDNode *Node) { SDNode *PTXDAGToDAGISel::SelectBRCOND(SDNode *Node) {
assert(Node->getNumOperands() >= 3); assert(Node->getNumOperands() >= 3);

View File

@@ -15,6 +15,7 @@
#include "PTXISelLowering.h" #include "PTXISelLowering.h"
#include "PTXMachineFunctionInfo.h" #include "PTXMachineFunctionInfo.h"
#include "PTXRegisterInfo.h" #include "PTXRegisterInfo.h"
#include "PTXSubtarget.h"
#include "llvm/Support/ErrorHandling.h" #include "llvm/Support/ErrorHandling.h"
#include "llvm/CodeGen/CallingConvLower.h" #include "llvm/CodeGen/CallingConvLower.h"
#include "llvm/CodeGen/MachineFunction.h" #include "llvm/CodeGen/MachineFunction.h"
@@ -106,6 +107,8 @@ const char *PTXTargetLowering::getTargetNodeName(unsigned Opcode) const {
return "PTXISD::COPY_ADDRESS"; return "PTXISD::COPY_ADDRESS";
case PTXISD::READ_PARAM: case PTXISD::READ_PARAM:
return "PTXISD::READ_PARAM"; return "PTXISD::READ_PARAM";
case PTXISD::STORE_PARAM:
return "PTXISD::STORE_PARAM";
case PTXISD::EXIT: case PTXISD::EXIT:
return "PTXISD::EXIT"; return "PTXISD::EXIT";
case PTXISD::RET: case PTXISD::RET:
@@ -192,6 +195,7 @@ SDValue PTXTargetLowering::
if (isVarArg) llvm_unreachable("PTX does not support varargs"); if (isVarArg) llvm_unreachable("PTX does not support varargs");
MachineFunction &MF = DAG.getMachineFunction(); MachineFunction &MF = DAG.getMachineFunction();
const PTXSubtarget& ST = getTargetMachine().getSubtarget<PTXSubtarget>();
PTXMachineFunctionInfo *MFI = MF.getInfo<PTXMachineFunctionInfo>(); PTXMachineFunctionInfo *MFI = MF.getInfo<PTXMachineFunctionInfo>();
switch (CallConv) { switch (CallConv) {
@@ -206,11 +210,16 @@ SDValue PTXTargetLowering::
break; break;
} }
if (MFI->isKernel()) { // We do one of two things here:
// For kernel functions, we just need to emit the proper READ_PARAM ISDs // IsKernel || SM >= 2.0 -> Use param space for arguments
// SM < 2.0 -> Use registers for arguments
if (MFI->isKernel() || ST.getShaderModel() >= PTXSubtarget::PTX_SM_2_0) {
// We just need to emit the proper READ_PARAM ISDs
for (unsigned i = 0, e = Ins.size(); i != e; ++i) { for (unsigned i = 0, e = Ins.size(); i != e; ++i) {
assert(Ins[i].VT != MVT::i1 && "Kernels cannot take pred operands"); assert((!MFI->isKernel() || Ins[i].VT != MVT::i1) &&
"Kernels cannot take pred operands");
SDValue ArgValue = DAG.getNode(PTXISD::READ_PARAM, dl, Ins[i].VT, Chain, SDValue ArgValue = DAG.getNode(PTXISD::READ_PARAM, dl, Ins[i].VT, Chain,
DAG.getTargetConstant(i, MVT::i32)); DAG.getTargetConstant(i, MVT::i32));
@@ -299,31 +308,49 @@ SDValue PTXTargetLowering::
MachineFunction& MF = DAG.getMachineFunction(); MachineFunction& MF = DAG.getMachineFunction();
PTXMachineFunctionInfo *MFI = MF.getInfo<PTXMachineFunctionInfo>(); PTXMachineFunctionInfo *MFI = MF.getInfo<PTXMachineFunctionInfo>();
SmallVector<CCValAssign, 16> RVLocs; const PTXSubtarget& ST = getTargetMachine().getSubtarget<PTXSubtarget>();
CCState CCInfo(CallConv, isVarArg, DAG.getMachineFunction(),
getTargetMachine(), RVLocs, *DAG.getContext());
SDValue Flag; SDValue Flag;
CCInfo.AnalyzeReturn(Outs, RetCC_PTX); if (ST.getShaderModel() >= PTXSubtarget::PTX_SM_2_0) {
// For SM 2.0+, we return arguments in the param space
for (unsigned i = 0, e = Outs.size(); i != e; ++i) {
SDVTList VTs = DAG.getVTList(MVT::Other, MVT::Glue);
SDValue ParamIndex = DAG.getTargetConstant(i, MVT::i32);
SDValue Ops[] = { Chain, ParamIndex, OutVals[i], Flag };
Chain = DAG.getNode(PTXISD::STORE_PARAM, dl, VTs, Ops,
Flag.getNode() ? 4 : 3);
Flag = Chain.getValue(1);
// Instead of storing a physical register in our argument list, we just
// store the total size of the parameter, in bits. The ASM printer
// knows how to process this.
MFI->addRetReg(Outs[i].VT.getStoreSizeInBits());
}
} else {
// For SM < 2.0, we return arguments in registers
SmallVector<CCValAssign, 16> RVLocs;
CCState CCInfo(CallConv, isVarArg, DAG.getMachineFunction(),
getTargetMachine(), RVLocs, *DAG.getContext());
for (unsigned i = 0, e = RVLocs.size(); i != e; ++i) { CCInfo.AnalyzeReturn(Outs, RetCC_PTX);
CCValAssign& VA = RVLocs[i]; for (unsigned i = 0, e = RVLocs.size(); i != e; ++i) {
CCValAssign& VA = RVLocs[i];
assert(VA.isRegLoc() && "CCValAssign must be RegLoc"); assert(VA.isRegLoc() && "CCValAssign must be RegLoc");
unsigned Reg = VA.getLocReg(); unsigned Reg = VA.getLocReg();
DAG.getMachineFunction().getRegInfo().addLiveOut(Reg); DAG.getMachineFunction().getRegInfo().addLiveOut(Reg);
Chain = DAG.getCopyToReg(Chain, dl, Reg, OutVals[i], Flag); Chain = DAG.getCopyToReg(Chain, dl, Reg, OutVals[i], Flag);
// Guarantee that all emitted copies are stuck together, // Guarantee that all emitted copies are stuck together,
// avoiding something bad // avoiding something bad
Flag = Chain.getValue(1); Flag = Chain.getValue(1);
MFI->addRetReg(Reg); MFI->addRetReg(Reg);
}
} }
if (Flag.getNode() == 0) { if (Flag.getNode() == 0) {

View File

@@ -25,11 +25,12 @@ namespace PTXISD {
enum NodeType { enum NodeType {
FIRST_NUMBER = ISD::BUILTIN_OP_END, FIRST_NUMBER = ISD::BUILTIN_OP_END,
READ_PARAM, READ_PARAM,
STORE_PARAM,
EXIT, EXIT,
RET, RET,
COPY_ADDRESS COPY_ADDRESS
}; };
} // namespace PTXISD } // namespace PTXISD
class PTXTargetLowering : public TargetLowering { class PTXTargetLowering : public TargetLowering {
public: public:

View File

@@ -180,10 +180,15 @@ def PTXsra : SDNode<"ISD::SRA", SDTIntBinOp>;
def PTXexit def PTXexit
: SDNode<"PTXISD::EXIT", SDTNone, [SDNPHasChain]>; : SDNode<"PTXISD::EXIT", SDTNone, [SDNPHasChain]>;
def PTXret def PTXret
: SDNode<"PTXISD::RET", SDTNone, [SDNPHasChain]>; : SDNode<"PTXISD::RET", SDTNone,
[SDNPHasChain, SDNPOptInGlue, SDNPVariadic]>;
def PTXcopyaddress def PTXcopyaddress
: SDNode<"PTXISD::COPY_ADDRESS", SDTypeProfile<1, 1, []>, []>; : SDNode<"PTXISD::COPY_ADDRESS", SDTypeProfile<1, 1, []>, []>;
def PTXstoreparam
: SDNode<"PTXISD::STORE_PARAM", SDTypeProfile<0, 2, [SDTCisVT<0, i32>]>,
[SDNPHasChain, SDNPOutGlue, SDNPOptInGlue]>;
//===----------------------------------------------------------------------===// //===----------------------------------------------------------------------===//
// Instruction Class Templates // Instruction Class Templates
//===----------------------------------------------------------------------===// //===----------------------------------------------------------------------===//
@@ -816,7 +821,7 @@ defm LDc : PTX_LD_ALL<"ld.const", load_constant>;
defm LDl : PTX_LD_ALL<"ld.local", load_local>; defm LDl : PTX_LD_ALL<"ld.local", load_local>;
defm LDs : PTX_LD_ALL<"ld.shared", load_shared>; defm LDs : PTX_LD_ALL<"ld.shared", load_shared>;
// This is a special instruction that is manually inserted for kernel parameters // This is a special instruction that is manually inserted for parameters
def LDpiU16 : InstPTX<(outs RegI16:$d), (ins MEMpi:$a), def LDpiU16 : InstPTX<(outs RegI16:$d), (ins MEMpi:$a),
"ld.param.u16\t$d, [$a]", []>; "ld.param.u16\t$d, [$a]", []>;
def LDpiU32 : InstPTX<(outs RegI32:$d), (ins MEMpi:$a), def LDpiU32 : InstPTX<(outs RegI32:$d), (ins MEMpi:$a),
@@ -828,6 +833,23 @@ def LDpiF32 : InstPTX<(outs RegF32:$d), (ins MEMpi:$a),
def LDpiF64 : InstPTX<(outs RegF64:$d), (ins MEMpi:$a), def LDpiF64 : InstPTX<(outs RegF64:$d), (ins MEMpi:$a),
"ld.param.f64\t$d, [$a]", []>; "ld.param.f64\t$d, [$a]", []>;
// def STpiPred : InstPTX<(outs), (ins i1imm:$d, RegPred:$a),
// "st.param.pred\t[$d], $a",
// [(PTXstoreparam imm:$d, RegPred:$a)]>;
// def STpiU16 : InstPTX<(outs), (ins i16imm:$d, RegI16:$a),
// "st.param.u16\t[$d], $a",
// [(PTXstoreparam imm:$d, RegI16:$a)]>;
def STpiU32 : InstPTX<(outs), (ins i32imm:$d, RegI32:$a),
"st.param.u32\t[$d], $a",
[(PTXstoreparam timm:$d, RegI32:$a)]>;
// def STpiU64 : InstPTX<(outs), (ins i64imm:$d, RegI64:$a),
// "st.param.u64\t[$d], $a",
// [(PTXstoreparam imm:$d, RegI64:$a)]>;
// def STpiF32 : InstPTX<(outs), (ins MEMpi:$d, RegF32:$a),
// "st.param.f32\t[$d], $a", []>;
// def STpiF64 : InstPTX<(outs), (ins MEMpi:$d, RegF64:$a),
// "st.param.f64\t[$d], $a", []>;
// Stores // Stores
defm STg : PTX_ST_ALL<"st.global", store_global>; defm STg : PTX_ST_ALL<"st.global", store_global>;
defm STl : PTX_ST_ALL<"st.local", store_local>; defm STl : PTX_ST_ALL<"st.local", store_local>;

View File

@@ -18,7 +18,7 @@
namespace llvm { namespace llvm {
class PTXSubtarget : public TargetSubtarget { class PTXSubtarget : public TargetSubtarget {
private: public:
/** /**
* Enumeration of Shader Models supported by the back-end. * Enumeration of Shader Models supported by the back-end.
@@ -41,6 +41,8 @@ namespace llvm {
PTX_VERSION_2_3 /*< PTX Version 2.3 */ PTX_VERSION_2_3 /*< PTX Version 2.3 */
}; };
private:
/// Shader Model supported on the target GPU. /// Shader Model supported on the target GPU.
PTXShaderModelEnum PTXShaderModel; PTXShaderModelEnum PTXShaderModel;
@@ -58,8 +60,10 @@ namespace llvm {
bool Is64Bit; bool Is64Bit;
public: public:
PTXSubtarget(const std::string &TT, const std::string &FS, bool is64Bit); PTXSubtarget(const std::string &TT, const std::string &FS, bool is64Bit);
// Target architecture accessors
std::string getTargetString() const; std::string getTargetString() const;
std::string getPTXVersionString() const; std::string getPTXVersionString() const;
@@ -80,6 +84,9 @@ namespace llvm {
bool supportsPTX23() const { return PTXVersion >= PTX_VERSION_2_3; } bool supportsPTX23() const { return PTXVersion >= PTX_VERSION_2_3; }
PTXShaderModelEnum getShaderModel() const { return PTXShaderModel; }
std::string ParseSubtargetFeatures(const std::string &FS, std::string ParseSubtargetFeatures(const std::string &FS,
const std::string &CPU); const std::string &CPU);
}; // class PTXSubtarget }; // class PTXSubtarget