PTX: Add initial support for device function calls

- Calls are supported on SM 2.0+ for function with no return values

git-svn-id: https://llvm.org/svn/llvm-project/llvm/trunk@137125 91177308-0d34-0410-b5e6-96231b3b80d8
This commit is contained in:
Justin Holewinski 2011-08-09 17:36:31 +00:00
parent 6d1fd0b979
commit 4bdd4ed564
7 changed files with 158 additions and 3 deletions

View File

@ -70,6 +70,8 @@ public:
const char *Modifier = 0);
void printPredicateOperand(const MachineInstr *MI, raw_ostream &O);
void printCall(const MachineInstr *MI, raw_ostream &O);
unsigned GetOrCreateSourceID(StringRef FileName,
StringRef DirName);
@ -242,6 +244,19 @@ void PTXAsmPrinter::EmitFunctionBodyStart() {
OutStreamer.EmitRawText(Twine(def));
}
}
unsigned Index = 1;
// Print parameter passing params
for (PTXMachineFunctionInfo::param_iterator
i = MFI->paramBegin(), e = MFI->paramEnd(); i != e; ++i) {
std::string def = "\t.param .b";
def += utostr(*i);
def += " __ret_";
def += utostr(Index);
Index++;
def += ";";
OutStreamer.EmitRawText(Twine(def));
}
}
void PTXAsmPrinter::EmitInstruction(const MachineInstr *MI) {
@ -302,7 +317,11 @@ void PTXAsmPrinter::EmitInstruction(const MachineInstr *MI) {
printPredicateOperand(MI, OS);
// Write instruction to str
printInstruction(MI, OS);
if (MI->getOpcode() == PTX::CALL) {
printCall(MI, OS);
} else {
printInstruction(MI, OS);
}
OS << ';';
OS.flush();
@ -569,6 +588,28 @@ printPredicateOperand(const MachineInstr *MI, raw_ostream &O) {
}
}
void PTXAsmPrinter::
printCall(const MachineInstr *MI, raw_ostream &O) {
O << "\tcall.uni\t";
const GlobalValue *Address = MI->getOperand(2).getGlobal();
O << Address->getName() << ", (";
// (0,1) : predicate register/flag
// (2) : callee
for (unsigned i = 3; i < MI->getNumOperands(); ++i) {
//const MachineOperand& MO = MI->getOperand(i);
printReturnOperand(MI, i, O);
if (i < MI->getNumOperands()-1) {
O << ", ";
}
}
O << ")";
}
unsigned PTXAsmPrinter::GetOrCreateSourceID(StringRef FileName,
StringRef DirName) {
// If FE did not provide a file name, then assume stdin.

View File

@ -22,6 +22,7 @@
#include "llvm/CodeGen/MachineRegisterInfo.h"
#include "llvm/CodeGen/SelectionDAG.h"
#include "llvm/CodeGen/TargetLoweringObjectFileImpl.h"
#include "llvm/Support/Debug.h"
#include "llvm/Support/raw_ostream.h"
using namespace llvm;
@ -134,6 +135,8 @@ const char *PTXTargetLowering::getTargetNodeName(unsigned Opcode) const {
return "PTXISD::EXIT";
case PTXISD::RET:
return "PTXISD::RET";
case PTXISD::CALL:
return "PTXISD::CALL";
}
}
@ -345,3 +348,49 @@ SDValue PTXTargetLowering::
return DAG.getNode(PTXISD::RET, dl, MVT::Other, Chain, Flag);
}
}
SDValue
PTXTargetLowering::LowerCall(SDValue Chain, SDValue Callee,
CallingConv::ID CallConv, bool isVarArg,
bool &isTailCall,
const SmallVectorImpl<ISD::OutputArg> &Outs,
const SmallVectorImpl<SDValue> &OutVals,
const SmallVectorImpl<ISD::InputArg> &Ins,
DebugLoc dl, SelectionDAG &DAG,
SmallVectorImpl<SDValue> &InVals) const {
MachineFunction& MF = DAG.getMachineFunction();
PTXMachineFunctionInfo *MFI = MF.getInfo<PTXMachineFunctionInfo>();
const PTXSubtarget& ST = getTargetMachine().getSubtarget<PTXSubtarget>();
assert(ST.callsAreHandled() && "Calls are not handled for the target device");
// Is there a more "LLVM"-way to create a variable-length array of values?
SDValue* ops = new SDValue[OutVals.size() + 2];
ops[0] = Chain;
if (GlobalAddressSDNode *G = dyn_cast<GlobalAddressSDNode>(Callee)) {
const GlobalValue *GV = G->getGlobal();
Callee = DAG.getTargetGlobalAddress(GV, dl, getPointerTy());
ops[1] = Callee;
} else {
assert(false && "Function must be a GlobalAddressSDNode");
}
for (unsigned i = 0; i != OutVals.size(); ++i) {
unsigned Size = OutVals[i].getValueType().getSizeInBits();
SDValue Index = DAG.getTargetConstant(MFI->getNextParam(Size), MVT::i32);
Chain = DAG.getNode(PTXISD::STORE_PARAM, dl, MVT::Other, Chain,
Index, OutVals[i]);
ops[i+2] = Index;
}
ops[0] = Chain;
Chain = DAG.getNode(PTXISD::CALL, dl, MVT::Other, ops, OutVals.size()+2);
delete [] ops;
return Chain;
}

View File

@ -28,7 +28,8 @@ namespace PTXISD {
STORE_PARAM,
EXIT,
RET,
COPY_ADDRESS
COPY_ADDRESS,
CALL
};
} // namespace PTXISD
@ -60,6 +61,16 @@ class PTXTargetLowering : public TargetLowering {
DebugLoc dl,
SelectionDAG &DAG) const;
virtual SDValue
LowerCall(SDValue Chain, SDValue Callee,
CallingConv::ID CallConv, bool isVarArg,
bool &isTailCall,
const SmallVectorImpl<ISD::OutputArg> &Outs,
const SmallVectorImpl<SDValue> &OutVals,
const SmallVectorImpl<ISD::InputArg> &Ins,
DebugLoc dl, SelectionDAG &DAG,
SmallVectorImpl<SDValue> &InVals) const;
virtual MVT::SimpleValueType getSetCCResultType(EVT VT) const;
private:

View File

@ -168,6 +168,18 @@ def MEMret : Operand<i32> {
let MIOperandInfo = (ops i32imm);
}
// def SDT_PTXCallSeqStart : SDCallSeqStart<[SDTCisVT<0, i32>]>;
// def SDT_PTXCallSeqEnd : SDCallSeqEnd<[SDTCisVT<0, i32>, SDTCisVT<1, i32>]>;
// def PTXcallseq_start : SDNode<"ISD::CALLSEQ_START", SDT_PTXCallSeqStart,
// [SDNPHasChain, SDNPOutGlue]>;
// def PTXcallseq_end : SDNode<"ISD::CALLSEQ_END", SDT_PTXCallSeqEnd,
// [SDNPHasChain, SDNPOptInGlue, SDNPOutGlue]>;
def PTXcall : SDNode<"PTXISD::CALL", SDTNone,
[SDNPHasChain, SDNPVariadic, SDNPOptInGlue, SDNPOutGlue]>;
// Branch & call targets have OtherVT type.
def brtarget : Operand<OtherVT>;
def calltarget : Operand<i32>;
@ -1073,6 +1085,11 @@ let isReturn = 1, isTerminator = 1, isBarrier = 1 in {
def RET : InstPTX<(outs), (ins), "ret", [(PTXret)]>;
}
let hasSideEffects = 1 in {
def CALL : InstPTX<(outs), (ins), "call", [(PTXcall)]>;
}
///===- Spill Instructions ------------------------------------------------===//
// Special instructions used for stack spilling
def STACKSTOREI16 : InstPTX<(outs), (ins i32imm:$d, RegI16:$a),
@ -1097,6 +1114,15 @@ def STACKLOADF32 : InstPTX<(outs), (ins RegF32:$d, i32imm:$a),
def STACKLOADF64 : InstPTX<(outs), (ins RegF64:$d, i32imm:$a),
"mov.f64\t$d, s$a", []>;
// Call handling
// def ADJCALLSTACKUP :
// InstPTX<(outs), (ins i32imm:$amt1, i32imm:$amt2), "",
// [(PTXcallseq_end timm:$amt1, timm:$amt2)]>;
// def ADJCALLSTACKDOWN :
// InstPTX<(outs), (ins i32imm:$amt), "",
// [(PTXcallseq_start timm:$amt)]>;
///===- Intrinsic Instructions --------------------------------------------===//
include "PTXIntrinsicInstrInfo.td"

View File

@ -27,6 +27,7 @@ private:
bool is_kernel;
std::vector<unsigned> reg_arg, reg_local_var;
std::vector<unsigned> reg_ret;
std::vector<unsigned> call_params;
bool _isDoneAddArg;
public:
@ -56,6 +57,7 @@ public:
typedef std::vector<unsigned>::const_iterator reg_iterator;
typedef std::vector<unsigned>::const_reverse_iterator reg_reverse_iterator;
typedef std::vector<unsigned>::const_iterator ret_iterator;
typedef std::vector<unsigned>::const_iterator param_iterator;
bool argRegEmpty() const { return reg_arg.empty(); }
int getNumArg() const { return reg_arg.size(); }
@ -73,6 +75,13 @@ public:
ret_iterator retRegBegin() const { return reg_ret.begin(); }
ret_iterator retRegEnd() const { return reg_ret.end(); }
param_iterator paramBegin() const { return call_params.begin(); }
param_iterator paramEnd() const { return call_params.end(); }
unsigned getNextParam(unsigned size) {
call_params.push_back(size);
return call_params.size()-1;
}
bool isArgReg(unsigned reg) const {
return std::find(reg_arg.begin(), reg_arg.end(), reg) != reg_arg.end();
}

View File

@ -114,7 +114,12 @@ class StringRef;
(PTXTarget >= PTX_COMPUTE_2_0 && PTXTarget < PTX_LAST_COMPUTE);
}
void ParseSubtargetFeatures(StringRef CPU, StringRef FS);
bool callsAreHandled() const {
return (PTXTarget >= PTX_SM_2_0 && PTXTarget < PTX_LAST_SM) ||
(PTXTarget >= PTX_COMPUTE_2_0 && PTXTarget < PTX_LAST_COMPUTE);
}
void ParseSubtargetFeatures(StringRef CPU, StringRef FS);
}; // class PTXSubtarget
} // namespace llvm

View File

@ -0,0 +1,14 @@
; RUN: llc < %s -march=ptx32 -mattr=sm20 | FileCheck %s
define ptx_device void @test_add(float %x, float %y) {
; CHECK: ret;
%z = fadd float %x, %y
ret void
}
define ptx_device float @test_call(float %x, float %y) {
%a = fadd float %x, %y
; CHECK: call.uni test_add, (__ret_{{[0-9]+}}, __ret_{{[0-9]+}});
call void @test_add(float %a, float %y)
ret float %a
}