From e953a64cb5888fbfa03f41e8258849bf979e22e5 Mon Sep 17 00:00:00 2001
From: Justin Holewinski <justin.holewinski@gmail.com>
Date: Fri, 23 Sep 2011 14:31:12 +0000
Subject: [PATCH] PTX: Start fixing function calls

git-svn-id: https://llvm.org/svn/llvm-project/llvm/trunk@140378 91177308-0d34-0410-b5e6-96231b3b80d8
---
 lib/Target/PTX/PTXAsmPrinter.cpp   | 13 ++++++++++++-
 lib/Target/PTX/PTXISelLowering.cpp |  4 +++-
 test/CodeGen/PTX/simple-call.ll    |  3 +--
 3 files changed, 16 insertions(+), 4 deletions(-)

diff --git a/lib/Target/PTX/PTXAsmPrinter.cpp b/lib/Target/PTX/PTXAsmPrinter.cpp
index 06cab0bc791..77164cac881 100644
--- a/lib/Target/PTX/PTXAsmPrinter.cpp
+++ b/lib/Target/PTX/PTXAsmPrinter.cpp
@@ -222,6 +222,7 @@ void PTXAsmPrinter::EmitFunctionBodyStart() {
   OutStreamer.EmitRawText(Twine("{"));
 
   const PTXMachineFunctionInfo *MFI = MF->getInfo<PTXMachineFunctionInfo>();
+  const PTXParamManager &PM = MFI->getParamManager();
 
   // Print register definitions
   std::string regDefs;
@@ -275,6 +276,16 @@ void PTXAsmPrinter::EmitFunctionBodyStart() {
     regDefs += ">;\n";
   }
 
+  // Local params
+  for (PTXParamManager::param_iterator i = PM.local_begin(), e = PM.local_end();
+       i != e; ++i) {
+    regDefs += "\t.param .b";
+    regDefs += utostr(PM.getParamSize(*i));
+    regDefs += " ";
+    regDefs += PM.getParamName(*i);
+    regDefs += ";\n";
+  }
+
   OutStreamer.EmitRawText(Twine(regDefs));
 
 
@@ -677,7 +688,7 @@ printCall(const MachineInstr *MI, raw_ostream &O) {
   for (unsigned i = 3; i < MI->getNumOperands(); ++i) {
     //const MachineOperand& MO = MI->getOperand(i);
 
-    printReturnOperand(MI, i, O);
+    printParamOperand(MI, i, O);
     if (i < MI->getNumOperands()-1) {
       O << ", ";
     }
diff --git a/lib/Target/PTX/PTXISelLowering.cpp b/lib/Target/PTX/PTXISelLowering.cpp
index 79967280344..3fdfcdf5749 100644
--- a/lib/Target/PTX/PTXISelLowering.cpp
+++ b/lib/Target/PTX/PTXISelLowering.cpp
@@ -435,6 +435,7 @@ PTXTargetLowering::LowerCall(SDValue Chain, SDValue Callee,
 
   MachineFunction& MF = DAG.getMachineFunction();
   PTXMachineFunctionInfo *MFI = MF.getInfo<PTXMachineFunctionInfo>();
+  PTXParamManager &PM = MFI->getParamManager();
 
   assert(getTargetMachine().getSubtarget<PTXSubtarget>().callsAreHandled() &&
          "Calls are not handled for the target device");
@@ -454,7 +455,8 @@ PTXTargetLowering::LowerCall(SDValue Chain, SDValue Callee,
 
   for (unsigned i = 0; i != OutVals.size(); ++i) {
     unsigned Size = OutVals[i].getValueType().getSizeInBits();
-    SDValue Index = DAG.getTargetConstant(MFI->getNextParam(Size), MVT::i32);
+    unsigned Param = PM.addLocalParam(Size);
+    SDValue Index = DAG.getTargetConstant(Param, MVT::i32);
     Chain = DAG.getNode(PTXISD::STORE_PARAM, dl, MVT::Other, Chain,
                         Index, OutVals[i]);
     ops[i+2] = Index;
diff --git a/test/CodeGen/PTX/simple-call.ll b/test/CodeGen/PTX/simple-call.ll
index 1647075e5e8..1e980655d3e 100644
--- a/test/CodeGen/PTX/simple-call.ll
+++ b/test/CodeGen/PTX/simple-call.ll
@@ -1,5 +1,4 @@
 ; RUN: llc < %s -march=ptx32 -mattr=sm20 | FileCheck %s
-; XFAIL: *
 
 define ptx_device void @test_add(float %x, float %y) {
 ; CHECK: ret;
@@ -9,7 +8,7 @@ define ptx_device void @test_add(float %x, float %y) {
 
 define ptx_device float @test_call(float %x, float %y) {
   %a = fadd float %x, %y
-; CHECK: call.uni test_add, (__ret_{{[0-9]+}}, __ret_{{[0-9]+}});
+; CHECK: call.uni test_add, (__localparam_{{[0-9]+}}, __localparam_{{[0-9]+}});
   call void @test_add(float %a, float %y)
   ret float %a
 }