From 88d3367baa066b4924a9303291aee084c154fff1 Mon Sep 17 00:00:00 2001
From: Che-Liang Chiou <clchiou@gmail.com>
Date: Fri, 18 Mar 2011 11:08:52 +0000
Subject: [PATCH] ptx: add unconditional and conditional branch

git-svn-id: https://llvm.org/svn/llvm-project/llvm/trunk@127873 91177308-0d34-0410-b5e6-96231b3b80d8
---
 lib/Target/PTX/PTXAsmPrinter.cpp    | 19 +++--------------
 lib/Target/PTX/PTXISelDAGToDAG.cpp  | 33 +++++++++++++++++++++++++----
 lib/Target/PTX/PTXISelLowering.cpp  | 11 ++++++++--
 lib/Target/PTX/PTXInstrInfo.td      | 18 +++++++++++++---
 lib/Target/PTX/PTXMCAsmStreamer.cpp |  2 +-
 test/CodeGen/PTX/bra.ll             | 21 ++++++++++++++++++
 6 files changed, 78 insertions(+), 26 deletions(-)
 create mode 100644 test/CodeGen/PTX/bra.ll

diff --git a/lib/Target/PTX/PTXAsmPrinter.cpp b/lib/Target/PTX/PTXAsmPrinter.cpp
index 0b07f749b3e..19a699e44fe 100644
--- a/lib/Target/PTX/PTXAsmPrinter.cpp
+++ b/lib/Target/PTX/PTXAsmPrinter.cpp
@@ -91,17 +91,6 @@ static const char *getRegisterTypeName(unsigned RegNo) {
   return NULL;
 }
 
-static const char *getInstructionTypeName(const MachineInstr *MI) {
-  for (int i = 0, e = MI->getNumOperands(); i != e; ++i) {
-    const MachineOperand &MO = MI->getOperand(i);
-    if (MO.getType() == MachineOperand::MO_Register)
-      return getRegisterTypeName(MO.getReg());
-  }
-
-  llvm_unreachable("No reg operand found in instruction!");
-  return NULL;
-}
-
 static const char *getStateSpaceName(unsigned addressSpace) {
   switch (addressSpace) {
   default: llvm_unreachable("Unknown state space");
@@ -221,11 +210,6 @@ void PTXAsmPrinter::EmitInstruction(const MachineInstr *MI) {
   OS << ';';
   OS.flush();
 
-  // Replace "%type" if found
-  size_t pos;
-  if ((pos = str.find("%type")) != std::string::npos)
-    str.replace(pos, /*strlen("%type")==*/5, getInstructionTypeName(MI));
-
   StringRef strref = StringRef(str);
   OutStreamer.EmitRawText(strref);
 }
@@ -244,6 +228,9 @@ void PTXAsmPrinter::printOperand(const MachineInstr *MI, int opNum,
     case MachineOperand::MO_Immediate:
       OS << (int) MO.getImm();
       break;
+    case MachineOperand::MO_MachineBasicBlock:
+      OS << *MO.getMBB()->getSymbol();
+      break;
     case MachineOperand::MO_Register:
       OS << getRegisterName(MO.getReg());
       break;
diff --git a/lib/Target/PTX/PTXISelDAGToDAG.cpp b/lib/Target/PTX/PTXISelDAGToDAG.cpp
index fe2d25a3b43..5be1224dcfd 100644
--- a/lib/Target/PTX/PTXISelDAGToDAG.cpp
+++ b/lib/Target/PTX/PTXISelDAGToDAG.cpp
@@ -43,6 +43,10 @@ class PTXDAGToDAGISel : public SelectionDAGISel {
   private:
     SDNode *SelectREAD_PARAM(SDNode *Node);
 
+    // We need this only because we can't match intruction BRAdp
+    // pattern (PTXbrcond bb:$d, ...) in PTXInstrInfo.td
+    SDNode *SelectBRCOND(SDNode *Node);
+
     bool isImm(const SDValue &operand);
     bool SelectImm(const SDValue &operand, SDValue &imm);
 
@@ -62,10 +66,14 @@ PTXDAGToDAGISel::PTXDAGToDAGISel(PTXTargetMachine &TM,
   : SelectionDAGISel(TM, OptLevel) {}
 
 SDNode *PTXDAGToDAGISel::Select(SDNode *Node) {
-  if (Node->getOpcode() == PTXISD::READ_PARAM)
-    return SelectREAD_PARAM(Node);
-  else
-    return SelectCode(Node);
+  switch (Node->getOpcode()) {
+    case PTXISD::READ_PARAM:
+      return SelectREAD_PARAM(Node);
+    case ISD::BRCOND:
+      return SelectBRCOND(Node);
+    default:
+      return SelectCode(Node);
+  }
 }
 
 SDNode *PTXDAGToDAGISel::SelectREAD_PARAM(SDNode *Node) {
@@ -99,6 +107,23 @@ SDNode *PTXDAGToDAGISel::SelectREAD_PARAM(SDNode *Node) {
     GetPTXMachineNode(CurDAG, opcode, dl, Node->getValueType(0), index);
 }
 
+SDNode *PTXDAGToDAGISel::SelectBRCOND(SDNode *Node) {
+  assert(Node->getNumOperands() >= 3);
+
+  SDValue Chain  = Node->getOperand(0);
+  SDValue Pred   = Node->getOperand(1);
+  SDValue Target = Node->getOperand(2); // branch target
+  SDValue PredOp = CurDAG->getTargetConstant(PTX::PRED_NORMAL, MVT::i32);
+  DebugLoc dl = Node->getDebugLoc();
+
+  assert(Target.getOpcode()  == ISD::BasicBlock);
+  assert(Pred.getValueType() == MVT::i1);
+
+  // Emit BRAdp
+  SDValue Ops[] = { Target, Pred, PredOp, Chain };
+  return CurDAG->getMachineNode(PTX::BRAdp, dl, MVT::Other, Ops, 4);
+}
+
 // Match memory operand of the form [reg+reg]
 bool PTXDAGToDAGISel::SelectADDRrr(SDValue &Addr, SDValue &R1, SDValue &R2) {
   if (Addr.getOpcode() != ISD::ADD || Addr.getNumOperands() < 2 ||
diff --git a/lib/Target/PTX/PTXISelLowering.cpp b/lib/Target/PTX/PTXISelLowering.cpp
index 1b4b7e6625d..159a27a429f 100644
--- a/lib/Target/PTX/PTXISelLowering.cpp
+++ b/lib/Target/PTX/PTXISelLowering.cpp
@@ -42,14 +42,21 @@ PTXTargetLowering::PTXTargetLowering(TargetMachine &TM)
   // Customize translation of memory addresses
   setOperationAction(ISD::GlobalAddress, MVT::i32, Custom);
 
+  // Expand BR_CC into BRCOND
+  setOperationAction(ISD::BR_CC, MVT::Other, Expand);
+
   // Compute derived properties from the register classes
   computeRegisterProperties();
 }
 
 SDValue PTXTargetLowering::LowerOperation(SDValue Op, SelectionDAG &DAG) const {
   switch (Op.getOpcode()) {
-    default:                 llvm_unreachable("Unimplemented operand");
-    case ISD::GlobalAddress: return LowerGlobalAddress(Op, DAG);
+    default:
+      llvm_unreachable("Unimplemented operand");
+    case ISD::GlobalAddress:
+      return LowerGlobalAddress(Op, DAG);
+    case ISD::BRCOND:
+      return LowerGlobalAddress(Op, DAG);
   }
 }
 
diff --git a/lib/Target/PTX/PTXInstrInfo.td b/lib/Target/PTX/PTXInstrInfo.td
index dc96914ba6e..42671b6e69c 100644
--- a/lib/Target/PTX/PTXInstrInfo.td
+++ b/lib/Target/PTX/PTXInstrInfo.td
@@ -37,7 +37,6 @@ def DoesNotSupportPTX21 : Predicate<"!getSubtarget().supportsPTX21()">;
 def SupportsPTX22       : Predicate<"getSubtarget().supportsPTX22()">;
 def DoesNotSupportPTX22 : Predicate<"!getSubtarget().supportsPTX22()">;
 
-
 //===----------------------------------------------------------------------===//
 // Instruction Pattern Stuff
 //===----------------------------------------------------------------------===//
@@ -135,7 +134,6 @@ def ADDRri64 : ComplexPattern<i64, 2, "SelectADDRri", [], []>;
 def ADDRii32 : ComplexPattern<i32, 2, "SelectADDRii", [], []>;
 def ADDRii64 : ComplexPattern<i64, 2, "SelectADDRii", [], []>;
 
-
 // Address operands
 def MEMri32 : Operand<i32> {
   let PrintMethod = "printMemOperand";
@@ -160,6 +158,9 @@ def MEMpi : Operand<i32> {
   let MIOperandInfo = (ops i32imm);
 }
 
+// Branch & call targets have OtherVT type.
+def brtarget   : Operand<OtherVT>;
+def calltarget : Operand<i32>;
 
 //===----------------------------------------------------------------------===//
 // PTX Specific Node Definitions
@@ -281,7 +282,6 @@ multiclass PTX_LOGIC<string opcstr, SDNode opnode> {
                      [(set RRegu64:$d, (opnode RRegu64:$a, imm:$b))]>;
 }
 
-// no %type directive, non-communtable
 multiclass INT3ntnc<string opcstr, SDNode opnode> {
   def rr : InstPTX<(outs RRegu32:$d),
                    (ins RRegu32:$a, RRegu32:$b),
@@ -566,6 +566,18 @@ def CVT_u32_pred
 
 ///===- Control Flow Instructions -----------------------------------------===//
 
+let isBranch = 1, isTerminator = 1 in {
+  def BRAd
+    : InstPTX<(outs), (ins brtarget:$d), "bra\t$d", [(br bb:$d)]>;
+
+  // FIXME: should be able to write a pattern for brcond, but can't use
+  //        a two-value operand where a dag node expects two operands. :(
+  // NOTE: ARM & PowerPC backend also report the same problem
+  def BRAdp
+    : InstPTX<(outs), (ins brtarget:$d), "bra\t$d",
+              [/*(brcond bb:$d, Preds:$p, i32imm:$c)*/]>;
+}
+
 let isReturn = 1, isTerminator = 1, isBarrier = 1 in {
   def EXIT : InstPTX<(outs), (ins), "exit", [(PTXexit)]>;
   def RET  : InstPTX<(outs), (ins), "ret",  [(PTXret)]>;
diff --git a/lib/Target/PTX/PTXMCAsmStreamer.cpp b/lib/Target/PTX/PTXMCAsmStreamer.cpp
index 1c5d418c30d..cf743e48c99 100644
--- a/lib/Target/PTX/PTXMCAsmStreamer.cpp
+++ b/lib/Target/PTX/PTXMCAsmStreamer.cpp
@@ -233,7 +233,7 @@ void PTXMCAsmStreamer::ChangeSection(const MCSection *Section) {
 void PTXMCAsmStreamer::EmitLabel(MCSymbol *Symbol) {
   assert(Symbol->isUndefined() && "Cannot define a symbol twice!");
   assert(!Symbol->isVariable() && "Cannot emit a variable symbol!");
-  assert(getCurrentSection() && "Cannot emit before setting section!");
+  //assert(getCurrentSection() && "Cannot emit before setting section!");
 
   OS << *Symbol << MAI.getLabelSuffix();
   EmitEOL();
diff --git a/test/CodeGen/PTX/bra.ll b/test/CodeGen/PTX/bra.ll
new file mode 100644
index 00000000000..07916974b5b
--- /dev/null
+++ b/test/CodeGen/PTX/bra.ll
@@ -0,0 +1,21 @@
+; RUN: llc < %s -march=ptx | FileCheck %s
+
+define ptx_device void @test_bra_direct() {
+; CHECK: bra $L__BB0_1;
+entry:
+	br label %loop
+loop:
+	br label %loop
+}
+
+define ptx_device i32 @test_bra_cond_direct(i32 %x, i32 %y) {
+entry:
+	%p = icmp ugt i32 %x, %y
+	br i1 %p, label %clause.if, label %clause.else
+clause.if:
+; CHECK: mov.u32 r0, r1
+	ret i32 %x
+clause.else:
+; CHECK: mov.u32 r0, r2
+	ret i32 %y
+}