ptx: add passing parameter to kernel functions

git-svn-id: https://llvm.org/svn/llvm-project/llvm/trunk@125279 91177308-0d34-0410-b5e6-96231b3b80d8
This commit is contained in:
Che-Liang Chiou 2011-02-10 12:01:24 +00:00
parent cbf023d7ec
commit 8e5d01cd6e
9 changed files with 104 additions and 62 deletions

View File

@ -38,12 +38,11 @@
using namespace llvm;
static cl::opt<std::string>
OptPTXVersion("ptx-version", cl::desc("Set PTX version"),
cl::init("1.4"));
OptPTXVersion("ptx-version", cl::desc("Set PTX version"), cl::init("1.4"));
static cl::opt<std::string>
OptPTXTarget("ptx-target", cl::desc("Set GPU target (comma-separated list)"),
cl::init("sm_10"));
cl::init("sm_10"));
namespace {
class PTXAsmPrinter : public AsmPrinter {
@ -67,6 +66,8 @@ public:
void printOperand(const MachineInstr *MI, int opNum, raw_ostream &OS);
void printMemOperand(const MachineInstr *MI, int opNum, raw_ostream &OS,
const char *Modifier = 0);
void printParamOperand(const MachineInstr *MI, int opNum, raw_ostream &OS,
const char *Modifier = 0);
// autogen'd.
void printInstruction(const MachineInstr *MI, raw_ostream &OS);
@ -231,6 +232,11 @@ void PTXAsmPrinter::printMemOperand(const MachineInstr *MI, int opNum,
printOperand(MI, opNum+1, OS);
}
void PTXAsmPrinter::printParamOperand(const MachineInstr *MI, int opNum,
raw_ostream &OS, const char *Modifier) {
OS << PARAM_PREFIX << (int) MI->getOperand(opNum).getImm() + 1;
}
void PTXAsmPrinter::EmitVariableDeclaration(const GlobalVariable *gv) {
// Check to see if this is a special global used by LLVM, if so, emit it.
if (EmitSpecialLLVMGlobal(gv))

View File

@ -40,6 +40,8 @@ class PTXDAGToDAGISel : public SelectionDAGISel {
#include "PTXGenDAGISel.inc"
private:
SDNode *SelectREAD_PARAM(SDNode *Node);
bool isImm(const SDValue &operand);
bool SelectImm(const SDValue &operand, SDValue &imm);
}; // class PTXDAGToDAGISel
@ -57,8 +59,21 @@ PTXDAGToDAGISel::PTXDAGToDAGISel(PTXTargetMachine &TM,
: SelectionDAGISel(TM, OptLevel) {}
SDNode *PTXDAGToDAGISel::Select(SDNode *Node) {
// SelectCode() is auto'gened
return SelectCode(Node);
if (Node->getOpcode() == PTXISD::READ_PARAM)
return SelectREAD_PARAM(Node);
else
return SelectCode(Node);
}
SDNode *PTXDAGToDAGISel::SelectREAD_PARAM(SDNode *Node) {
SDValue index = Node->getOperand(1);
DebugLoc dl = Node->getDebugLoc();
if (index.getOpcode() != ISD::TargetConstant)
llvm_unreachable("READ_PARAM: index is not ISD::TargetConstant");
return PTXInstrInfo::
GetPTXMachineNode(CurDAG, PTX::LDpi, dl, MVT::i32, index);
}
// Match memory operand of the form [reg+reg]

View File

@ -47,9 +47,14 @@ SDValue PTXTargetLowering::LowerOperation(SDValue Op, SelectionDAG &DAG) const {
const char *PTXTargetLowering::getTargetNodeName(unsigned Opcode) const {
switch (Opcode) {
default: llvm_unreachable("Unknown opcode");
case PTXISD::EXIT: return "PTXISD::EXIT";
case PTXISD::RET: return "PTXISD::RET";
default:
llvm_unreachable("Unknown opcode");
case PTXISD::READ_PARAM:
return "PTXISD::READ_PARAM";
case PTXISD::EXIT:
return "PTXISD::EXIT";
case PTXISD::RET:
return "PTXISD::RET";
}
}
@ -86,42 +91,6 @@ struct argmap_entry {
};
} // end anonymous namespace
static SDValue lower_kernel_argument(int i,
SDValue Chain,
DebugLoc dl,
MVT::SimpleValueType VT,
argmap_entry *entry,
SelectionDAG &DAG,
unsigned *argreg) {
// TODO
llvm_unreachable("Not implemented yet");
}
static SDValue lower_device_argument(int i,
SDValue Chain,
DebugLoc dl,
MVT::SimpleValueType VT,
argmap_entry *entry,
SelectionDAG &DAG,
unsigned *argreg) {
MachineRegisterInfo &RegInfo = DAG.getMachineFunction().getRegInfo();
unsigned preg = *++(entry->loc); // allocate start from register 1
unsigned vreg = RegInfo.createVirtualRegister(entry->RC);
RegInfo.addLiveIn(preg, vreg);
*argreg = preg;
return DAG.getCopyFromReg(Chain, dl, vreg, VT);
}
typedef SDValue (*lower_argument_func)(int i,
SDValue Chain,
DebugLoc dl,
MVT::SimpleValueType VT,
argmap_entry *entry,
SelectionDAG &DAG,
unsigned *argreg);
SDValue PTXTargetLowering::
LowerFormalArguments(SDValue Chain,
CallingConv::ID CallConv,
@ -135,22 +104,22 @@ SDValue PTXTargetLowering::
MachineFunction &MF = DAG.getMachineFunction();
PTXMachineFunctionInfo *MFI = MF.getInfo<PTXMachineFunctionInfo>();
lower_argument_func lower_argument;
switch (CallConv) {
default:
llvm_unreachable("Unsupported calling convention");
break;
case CallingConv::PTX_Kernel:
MFI->setKernel();
lower_argument = lower_kernel_argument;
MFI->setKernel(true);
break;
case CallingConv::PTX_Device:
MFI->setKernel(false);
lower_argument = lower_device_argument;
break;
}
// Make sure we don't add argument registers twice
if (MFI->isDoneAddArg())
llvm_unreachable("cannot add argument registers twice");
// Reset argmap before allocation
for (struct argmap_entry *i = argmap, *e = argmap + array_lengthof(argmap);
i != e; ++ i)
@ -164,17 +133,27 @@ SDValue PTXTargetLowering::
if (entry == argmap + array_lengthof(argmap))
llvm_unreachable("Type of argument is not supported");
unsigned reg;
SDValue arg = lower_argument(i, Chain, dl, VT, entry, DAG, &reg);
InVals.push_back(arg);
if (MFI->isKernel() && entry->RC == PTX::PredsRegisterClass)
llvm_unreachable("cannot pass preds to kernel");
if (!MFI->isDoneAddArg())
MFI->addArgReg(reg);
MachineRegisterInfo &RegInfo = DAG.getMachineFunction().getRegInfo();
unsigned preg = *++(entry->loc); // allocate start from register 1
unsigned vreg = RegInfo.createVirtualRegister(entry->RC);
RegInfo.addLiveIn(preg, vreg);
MFI->addArgReg(preg);
SDValue inval;
if (MFI->isKernel())
inval = DAG.getNode(PTXISD::READ_PARAM, dl, VT, Chain,
DAG.getTargetConstant(i, MVT::i32));
else
inval = DAG.getCopyFromReg(Chain, dl, vreg, VT);
InVals.push_back(inval);
}
// Make sure we don't add argument registers twice
if (!MFI->isDoneAddArg())
MFI->doneAddArg();
MFI->doneAddArg();
return Chain;
}

View File

@ -24,6 +24,7 @@ class PTXTargetMachine;
namespace PTXISD {
enum NodeType {
FIRST_NUMBER = ISD::BUILTIN_OP_END,
READ_PARAM,
EXIT,
RET
};

View File

@ -15,6 +15,8 @@
#define PTX_INSTR_INFO_H
#include "PTXRegisterInfo.h"
#include "llvm/CodeGen/SelectionDAG.h"
#include "llvm/CodeGen/SelectionDAGNodes.h"
#include "llvm/Target/TargetInstrInfo.h"
namespace llvm {
@ -45,6 +47,28 @@ class PTXInstrInfo : public TargetInstrInfoImpl {
virtual bool isMoveInstr(const MachineInstr& MI,
unsigned &SrcReg, unsigned &DstReg,
unsigned &SrcSubIdx, unsigned &DstSubIdx) const;
// static helper routines
static MachineSDNode *GetPTXMachineNode(SelectionDAG *DAG, unsigned Opcode,
DebugLoc dl, EVT VT,
SDValue Op1) {
SDValue pred_reg = DAG->getRegister(0, MVT::i1);
SDValue pred_imm = DAG->getTargetConstant(0, MVT::i32);
SDValue ops[] = { Op1, pred_reg, pred_imm };
return DAG->getMachineNode(Opcode, dl, VT, ops, array_lengthof(ops));
}
static MachineSDNode *GetPTXMachineNode(SelectionDAG *DAG, unsigned Opcode,
DebugLoc dl, EVT VT,
SDValue Op1,
SDValue Op2) {
SDValue pred_reg = DAG->getRegister(0, MVT::i1);
SDValue pred_imm = DAG->getTargetConstant(0, MVT::i32);
SDValue ops[] = { Op1, Op2, pred_reg, pred_imm };
return DAG->getMachineNode(Opcode, dl, VT, ops, array_lengthof(ops));
}
}; // class PTXInstrInfo
} // namespace llvm

View File

@ -120,6 +120,10 @@ def MEMii : Operand<i32> {
let PrintMethod = "printMemOperand";
let MIOperandInfo = (ops i32imm, i32imm);
}
def MEMpi : Operand<i32> {
let PrintMethod = "printParamOperand";
let MIOperandInfo = (ops i32imm);
}
//===----------------------------------------------------------------------===//
// PTX Specific Node Definitions
@ -236,9 +240,13 @@ defm LDl : PTX_LD<"ld.local", RRegs32, load_local>;
defm LDp : PTX_LD<"ld.param", RRegs32, load_parameter>;
defm LDs : PTX_LD<"ld.shared", RRegs32, load_shared>;
def LDpi : InstPTX<(outs RRegs32:$d), (ins MEMpi:$a),
"ld.param.%type\t$d, [$a]", []>;
defm STg : PTX_ST<"st.global", RRegs32, store_global>;
defm STl : PTX_ST<"st.local", RRegs32, store_local>;
defm STp : PTX_ST<"st.param", RRegs32, store_parameter>;
// Store to parameter state space requires PTX 2.0 or higher?
// defm STp : PTX_ST<"st.param", RRegs32, store_parameter>;
defm STs : PTX_ST<"st.shared", RRegs32, store_shared>;
///===- Control Flow Instructions -----------------------------------------===//

View File

@ -67,7 +67,9 @@ bool PTXMFInfoExtract::runOnMachineFunction(MachineFunction &MF) {
// FIXME: This is a slow linear scanning
for (unsigned reg = PTX::NoRegister + 1; reg < PTX::NUM_TARGET_REGS; ++reg)
if (MRI.isPhysRegUsed(reg) && reg != retreg && !MFI->isArgReg(reg))
if (MRI.isPhysRegUsed(reg) &&
reg != retreg &&
(MFI->isKernel() || !MFI->isArgReg(reg)))
MFI->addLocalVarReg(reg);
// Notify MachineFunctionInfo that I've done adding local var reg

View File

@ -31,8 +31,8 @@ private:
public:
PTXMachineFunctionInfo(MachineFunction &MF)
: is_kernel(false), reg_ret(PTX::NoRegister), _isDoneAddArg(false) {
reg_arg.reserve(32);
reg_local_var.reserve(64);
reg_arg.reserve(8);
reg_local_var.reserve(32);
}
void setKernel(bool _is_kernel=true) { is_kernel = _is_kernel; }

View File

@ -3,5 +3,12 @@
define ptx_kernel void @t1() {
; CHECK: exit;
; CHECK-NOT: ret;
ret void
ret void
}
define ptx_kernel void @t2(i32* %p, i32 %x) {
store i32 %x, i32* %p
; CHECK: exit;
; CHECK-NOT: ret;
ret void
}