diff --git a/lib/Target/PTX/PTXAsmPrinter.cpp b/lib/Target/PTX/PTXAsmPrinter.cpp index 77164cac881..d2b7c5f6b55 100644 --- a/lib/Target/PTX/PTXAsmPrinter.cpp +++ b/lib/Target/PTX/PTXAsmPrinter.cpp @@ -677,21 +677,36 @@ printPredicateOperand(const MachineInstr *MI, raw_ostream &O) { void PTXAsmPrinter:: printCall(const MachineInstr *MI, raw_ostream &O) { - O << "\tcall.uni\t"; - - const GlobalValue *Address = MI->getOperand(2).getGlobal(); - O << Address->getName() << ", ("; - - // (0,1) : predicate register/flag - // (2) : callee - for (unsigned i = 3; i < MI->getNumOperands(); ++i) { - //const MachineOperand& MO = MI->getOperand(i); - - printParamOperand(MI, i, O); - if (i < MI->getNumOperands()-1) { + // The first two operands are the predicate slot + unsigned Index = 2; + while (!MI->getOperand(Index).isGlobal()) { + if (Index == 2) { + O << "("; + } else { O << ", "; } + printParamOperand(MI, Index, O); + Index++; + } + + if (Index != 2) { + O << "), "; + } + + assert(MI->getOperand(Index).isGlobal() && + "A GlobalAddress must follow the return arguments"); + + const GlobalValue *Address = MI->getOperand(Index).getGlobal(); + O << Address->getName() << ", ("; + Index++; + + while (Index < MI->getNumOperands()) { + printParamOperand(MI, Index, O); + if (Index < MI->getNumOperands()-1) { + O << ", "; + } + Index++; } O << ")"; diff --git a/lib/Target/PTX/PTXISelLowering.cpp b/lib/Target/PTX/PTXISelLowering.cpp index 3fdfcdf5749..053e140efe8 100644 --- a/lib/Target/PTX/PTXISelLowering.cpp +++ b/lib/Target/PTX/PTXISelLowering.cpp @@ -16,6 +16,7 @@ #include "PTXMachineFunctionInfo.h" #include "PTXRegisterInfo.h" #include "PTXSubtarget.h" +#include "llvm/Function.h" #include "llvm/Support/ErrorHandling.h" #include "llvm/CodeGen/CallingConvLower.h" #include "llvm/CodeGen/MachineFunction.h" @@ -440,15 +441,22 @@ PTXTargetLowering::LowerCall(SDValue Chain, SDValue Callee, assert(getTargetMachine().getSubtarget().callsAreHandled() && "Calls are not handled for the target device"); - // Is there a more "LLVM"-way to create a variable-length array of values? - SDValue* ops = new SDValue[OutVals.size() + 2]; + std::vector Ops; + // The layout of the ops will be [Chain, Ins, Callee, Outs] + Ops.resize(Outs.size() + Ins.size() + 2); - ops[0] = Chain; + Ops[0] = Chain; if (GlobalAddressSDNode *G = dyn_cast(Callee)) { const GlobalValue *GV = G->getGlobal(); - Callee = DAG.getTargetGlobalAddress(GV, dl, getPointerTy()); - ops[1] = Callee; + if (const Function *F = dyn_cast(GV)) { + assert(F->getCallingConv() == CallingConv::PTX_Device && + "PTX function calls must be to PTX device functions"); + Callee = DAG.getTargetGlobalAddress(GV, dl, getPointerTy()); + Ops[Ins.size()+1] = Callee; + } else { + assert(false && "GlobalValue is not a function"); + } } else { assert(false && "Function must be a GlobalAddressSDNode"); } @@ -459,14 +467,28 @@ PTXTargetLowering::LowerCall(SDValue Chain, SDValue Callee, SDValue Index = DAG.getTargetConstant(Param, MVT::i32); Chain = DAG.getNode(PTXISD::STORE_PARAM, dl, MVT::Other, Chain, Index, OutVals[i]); - ops[i+2] = Index; + Ops[i+Ins.size()+2] = Index; } - ops[0] = Chain; + std::vector InParams; - Chain = DAG.getNode(PTXISD::CALL, dl, MVT::Other, ops, OutVals.size()+2); + for (unsigned i = 0; i < Ins.size(); ++i) { + unsigned Size = Ins[i].VT.getStoreSizeInBits(); + unsigned Param = PM.addLocalParam(Size); + SDValue Index = DAG.getTargetConstant(Param, MVT::i32); + Ops[i+1] = Index; + InParams.push_back(Param); + } - delete [] ops; + Ops[0] = Chain; + + Chain = DAG.getNode(PTXISD::CALL, dl, MVT::Other, &Ops[0], Ops.size()); + + for (unsigned i = 0; i < Ins.size(); ++i) { + SDValue Index = DAG.getTargetConstant(InParams[i], MVT::i32); + SDValue Load = DAG.getNode(PTXISD::LOAD_PARAM, dl, Ins[i].VT, Chain, Index); + InVals.push_back(Load); + } return Chain; } diff --git a/test/CodeGen/PTX/simple-call.ll b/test/CodeGen/PTX/simple-call.ll index 1e980655d3e..77ea29eae8b 100644 --- a/test/CodeGen/PTX/simple-call.ll +++ b/test/CodeGen/PTX/simple-call.ll @@ -12,3 +12,16 @@ define ptx_device float @test_call(float %x, float %y) { call void @test_add(float %a, float %y) ret float %a } + +define ptx_device float @test_compute(float %x, float %y) { +; CHECK: ret; + %z = fadd float %x, %y + ret float %z +} + +define ptx_device float @test_call_compute(float %x, float %y) { +; CHECK: call.uni (__localparam_{{[0-9]+}}), test_compute, (__localparam_{{[0-9]+}}, __localparam_{{[0-9]+}}) + %z = call float @test_compute(float %x, float %y) + ret float %z +} +