ptx: add ld instruction and test

git-svn-id: https://llvm.org/svn/llvm-project/llvm/trunk@122398 91177308-0d34-0410-b5e6-96231b3b80d8
This commit is contained in:
Che-Liang Chiou 2010-12-22 10:38:51 +00:00
parent a3c44a5280
commit fc7072c3c4
7 changed files with 224 additions and 21 deletions

View File

@ -17,7 +17,8 @@
#include "PTX.h"
#include "PTXMachineFunctionInfo.h"
#include "PTXTargetMachine.h"
#include "llvm/Support/raw_ostream.h"
#include "llvm/DerivedTypes.h"
#include "llvm/Module.h"
#include "llvm/ADT/SmallString.h"
#include "llvm/ADT/StringExtras.h"
#include "llvm/ADT/Twine.h"
@ -25,11 +26,13 @@
#include "llvm/CodeGen/MachineInstr.h"
#include "llvm/MC/MCStreamer.h"
#include "llvm/MC/MCSymbol.h"
#include "llvm/Target/Mangler.h"
#include "llvm/Target/TargetLoweringObjectFile.h"
#include "llvm/Target/TargetRegistry.h"
#include "llvm/Support/CommandLine.h"
#include "llvm/Support/Debug.h"
#include "llvm/Support/ErrorHandling.h"
#include "llvm/Support/MathExtras.h"
#include "llvm/Support/raw_ostream.h"
using namespace llvm;
@ -50,6 +53,8 @@ public:
const char *getPassName() const { return "PTX Assembly Printer"; }
bool doFinalization(Module &M);
virtual void EmitStartOfAsmFile(Module &M);
virtual bool runOnMachineFunction(MachineFunction &MF);
@ -68,6 +73,7 @@ public:
static const char *getRegisterName(unsigned RegNo);
private:
void EmitVariableDeclaration(const GlobalVariable *gv);
void EmitFunctionDeclaration();
}; // class PTXAsmPrinter
} // namespace
@ -96,11 +102,54 @@ static const char *getInstructionTypeName(const MachineInstr *MI) {
return NULL;
}
static const char *getStateSpaceName(unsigned addressSpace) {
if (addressSpace <= 255)
return "global";
// TODO Add more state spaces
llvm_unreachable("Unknown state space");
return NULL;
}
bool PTXAsmPrinter::doFinalization(Module &M) {
// XXX Temproarily remove global variables so that doFinalization() will not
// emit them again (global variables are emitted at beginning).
Module::GlobalListType &global_list = M.getGlobalList();
int i, n = global_list.size();
GlobalVariable **gv_array = new GlobalVariable* [n];
// first, back-up GlobalVariable in gv_array
i = 0;
for (Module::global_iterator I = global_list.begin(), E = global_list.end();
I != E; ++I)
gv_array[i++] = &*I;
// second, empty global_list
while (!global_list.empty())
global_list.remove(global_list.begin());
// call doFinalization
bool ret = AsmPrinter::doFinalization(M);
// now we restore global variables
for (i = 0; i < n; i ++)
global_list.insert(global_list.end(), gv_array[i]);
delete[] gv_array;
return ret;
}
void PTXAsmPrinter::EmitStartOfAsmFile(Module &M)
{
OutStreamer.EmitRawText(Twine("\t.version " + OptPTXVersion));
OutStreamer.EmitRawText(Twine("\t.target " + OptPTXTarget));
OutStreamer.AddBlankLine();
// declare global variables
for (Module::const_global_iterator i = M.global_begin(), e = M.global_end();
i != e; ++i)
EmitVariableDeclaration(i);
}
bool PTXAsmPrinter::runOnMachineFunction(MachineFunction &MF) {
@ -156,12 +205,15 @@ void PTXAsmPrinter::printOperand(const MachineInstr *MI, int opNum,
default:
llvm_unreachable("<unknown operand type>");
break;
case MachineOperand::MO_Register:
OS << getRegisterName(MO.getReg());
case MachineOperand::MO_GlobalAddress:
OS << *Mang->getSymbol(MO.getGlobal());
break;
case MachineOperand::MO_Immediate:
OS << (int) MO.getImm();
break;
case MachineOperand::MO_Register:
OS << getRegisterName(MO.getReg());
break;
}
}
@ -176,6 +228,49 @@ void PTXAsmPrinter::printMemOperand(const MachineInstr *MI, int opNum,
printOperand(MI, opNum+1, OS);
}
void PTXAsmPrinter::EmitVariableDeclaration(const GlobalVariable *gv) {
// Check to see if this is a special global used by LLVM, if so, emit it.
if (EmitSpecialLLVMGlobal(gv))
return;
MCSymbol *gvsym = Mang->getSymbol(gv);
assert(gvsym->isUndefined() && "Cannot define a symbol twice!");
std::string decl;
// check if it is defined in some other translation unit
if (gv->isDeclaration())
decl += ".extern ";
// state space: e.g., .global
decl += ".";
decl += getStateSpaceName(gv->getType()->getAddressSpace());
decl += " ";
// alignment (optional)
unsigned alignment = gv->getAlignment();
if (alignment != 0) {
decl += ".align ";
decl += utostr(Log2_32(gv->getAlignment()));
decl += " ";
}
// TODO: add types
decl += ".s32 ";
decl += gvsym->getName();
if (ArrayType::classof(gv->getType()) || PointerType::classof(gv->getType()))
decl += "[]";
decl += ";";
OutStreamer.EmitRawText(Twine(decl));
OutStreamer.AddBlankLine();
}
void PTXAsmPrinter::EmitFunctionDeclaration() {
// The function label could have already been emitted if two symbols end up
// conflicting due to asm renaming. Detect this and emit an error.
@ -212,7 +307,7 @@ void PTXAsmPrinter::EmitFunctionDeclaration() {
for (int i = 0, e = MFI->getNumArg(); i != e; ++i) {
if (i != 0)
decl += ", ";
decl += ".param .s32 "; // TODO: param's type
decl += ".param .s32 "; // TODO: add types
decl += PARAM_PREFIX;
decl += utostr(i + 1);
}

View File

@ -32,6 +32,7 @@ class PTXDAGToDAGISel : public SelectionDAGISel {
SDNode *Select(SDNode *Node);
// Complex Pattern Selectors.
bool SelectADDRrr(SDValue &Addr, SDValue &R1, SDValue &R2);
bool SelectADDRri(SDValue &Addr, SDValue &Base, SDValue &Offset);
bool SelectADDRii(SDValue &Addr, SDValue &Base, SDValue &Offset);
@ -39,8 +40,8 @@ class PTXDAGToDAGISel : public SelectionDAGISel {
#include "PTXGenDAGISel.inc"
private:
bool isImm (const SDValue &operand);
bool SelectImm (const SDValue &operand, SDValue &imm);
bool isImm(const SDValue &operand);
bool SelectImm(const SDValue &operand, SDValue &imm);
}; // class PTXDAGToDAGISel
} // namespace
@ -60,35 +61,61 @@ SDNode *PTXDAGToDAGISel::Select(SDNode *Node) {
return SelectCode(Node);
}
// Match memory operand of the form [reg+reg] and [reg+imm]
// Match memory operand of the form [reg+reg]
bool PTXDAGToDAGISel::SelectADDRrr(SDValue &Addr, SDValue &R1, SDValue &R2) {
if (Addr.getOpcode() != ISD::ADD || Addr.getNumOperands() < 2 ||
isImm(Addr.getOperand(0)) || isImm(Addr.getOperand(1)))
return false;
R1 = Addr.getOperand(0);
R2 = Addr.getOperand(1);
return true;
}
// Match memory operand of the form [reg], [imm+reg], and [reg+imm]
bool PTXDAGToDAGISel::SelectADDRri(SDValue &Addr, SDValue &Base,
SDValue &Offset) {
if (Addr.getOpcode() != ISD::ADD) {
if (isImm(Addr))
return false;
// is [reg] but not [imm]
Base = Addr;
Offset = CurDAG->getTargetConstant(0, MVT::i32);
return true;
}
// let SelectADDRii handle the [imm+imm] case
if (Addr.getNumOperands() >= 2 &&
isImm(Addr.getOperand(0)) && isImm(Addr.getOperand(1)))
return false; // let SelectADDRii handle the [imm+imm] case
return false;
// try [reg+imm] and [imm+reg]
if (Addr.getOpcode() == ISD::ADD)
for (int i = 0; i < 2; i ++)
if (SelectImm(Addr.getOperand(1-i), Offset)) {
Base = Addr.getOperand(i);
return true;
}
for (int i = 0; i < 2; i ++)
if (SelectImm(Addr.getOperand(1-i), Offset)) {
Base = Addr.getOperand(i);
return true;
}
// okay, it's [reg+reg]
Base = Addr;
Offset = CurDAG->getTargetConstant(0, MVT::i32);
return true;
// either [reg+imm] and [imm+reg]
for (int i = 0; i < 2; i ++)
if (SelectImm(Addr.getOperand(1-i), Offset)) {
Base = Addr.getOperand(i);
return true;
}
return false;
}
// Match memory operand of the form [imm+imm] and [imm]
bool PTXDAGToDAGISel::SelectADDRii(SDValue &Addr, SDValue &Base,
SDValue &Offset) {
// is [imm+imm]?
if (Addr.getOpcode() == ISD::ADD) {
return SelectImm(Addr.getOperand(0), Base) &&
SelectImm(Addr.getOperand(1), Offset);
}
// is [imm]?
if (SelectImm(Addr, Base)) {
Offset = CurDAG->getTargetConstant(0, MVT::i32);
return true;

View File

@ -29,10 +29,22 @@ PTXTargetLowering::PTXTargetLowering(TargetMachine &TM)
addRegisterClass(MVT::i1, PTX::PredsRegisterClass);
addRegisterClass(MVT::i32, PTX::RRegs32RegisterClass);
setOperationAction(ISD::EXCEPTIONADDR, MVT::i32, Expand);
// Customize translation of memory addresses
setOperationAction(ISD::GlobalAddress, MVT::i32, Custom);
// Compute derived properties from the register classes
computeRegisterProperties();
}
SDValue PTXTargetLowering::LowerOperation(SDValue Op, SelectionDAG &DAG) const {
switch (Op.getOpcode()) {
default: llvm_unreachable("Unimplemented operand");
case ISD::GlobalAddress: return LowerGlobalAddress(Op, DAG);
}
}
const char *PTXTargetLowering::getTargetNodeName(unsigned Opcode) const {
switch (Opcode) {
default: llvm_unreachable("Unknown opcode");
@ -41,6 +53,18 @@ const char *PTXTargetLowering::getTargetNodeName(unsigned Opcode) const {
}
}
//===----------------------------------------------------------------------===//
// Custom Lower Operation
//===----------------------------------------------------------------------===//
SDValue PTXTargetLowering::
LowerGlobalAddress(SDValue Op, SelectionDAG &DAG) const {
EVT PtrVT = getPointerTy();
DebugLoc dl = Op.getDebugLoc();
const GlobalValue *GV = cast<GlobalAddressSDNode>(Op)->getGlobal();
return DAG.getTargetGlobalAddress(GV, dl, PtrVT);
}
//===----------------------------------------------------------------------===//
// Calling Convention Implementation
//===----------------------------------------------------------------------===//

View File

@ -38,6 +38,8 @@ class PTXTargetLowering : public TargetLowering {
virtual unsigned getFunctionAlignment(const Function *F) const {
return 2; }
virtual SDValue LowerOperation(SDValue Op, SelectionDAG &DAG) const;
virtual SDValue
LowerFormalArguments(SDValue Chain,
CallingConv::ID CallConv,
@ -55,6 +57,9 @@ class PTXTargetLowering : public TargetLowering {
const SmallVectorImpl<SDValue> &OutVals,
DebugLoc dl,
SelectionDAG &DAG) const;
private:
SDValue LowerGlobalAddress(SDValue Op, SelectionDAG &DAG) const;
}; // class PTXTargetLowering
} // namespace llvm

View File

@ -29,10 +29,15 @@ def load_global : PatFrag<(ops node:$ptr), (load node:$ptr), [{
}]>;
// Addressing modes.
def ADDRrr : ComplexPattern<i32, 2, "SelectADDRrr", [], []>;
def ADDRri : ComplexPattern<i32, 2, "SelectADDRri", [], []>;
def ADDRii : ComplexPattern<i32, 2, "SelectADDRii", [], []>;
// Address operands
def MEMrr : Operand<i32> {
let PrintMethod = "printMemOperand";
let MIOperandInfo = (ops RRegs32, RRegs32);
}
def MEMri : Operand<i32> {
let PrintMethod = "printMemOperand";
let MIOperandInfo = (ops RRegs32, i32imm);
@ -88,6 +93,10 @@ multiclass INT3ntnc<string opcstr, SDNode opnode> {
}
multiclass PTX_LD<string opstr, RegisterClass RC, PatFrag pat_load> {
def rr : InstPTX<(outs RC:$d),
(ins MEMrr:$a),
!strconcat(opstr, ".%type\t$d, [$a]"),
[(set RC:$d, (pat_load ADDRrr:$a))]>;
def ri : InstPTX<(outs RC:$d),
(ins MEMri:$a),
!strconcat(opstr, ".%type\t$d, [$a]"),

View File

@ -11,8 +11,8 @@
#include "llvm/ADT/SmallString.h"
#include "llvm/ADT/Twine.h"
#include "llvm/MC/MCAsmInfo.h"
#include "llvm/MC/MCContext.h"
#include "llvm/MC/MCCodeEmitter.h"
#include "llvm/MC/MCContext.h"
#include "llvm/MC/MCExpr.h"
#include "llvm/MC/MCInst.h"
#include "llvm/MC/MCInstPrinter.h"
@ -102,8 +102,7 @@ public:
virtual void SwitchSection(const MCSection *Section);
virtual void InitSections() {
}
virtual void InitSections() {}
virtual void EmitLabel(MCSymbol *Symbol);

44
test/CodeGen/PTX/ld.ll Normal file
View File

@ -0,0 +1,44 @@
; RUN: llc < %s -march=ptx | FileCheck %s
;CHECK: .extern .global .s32 array[];
@array = external global [10 x i32]
define ptx_device i32 @t1(i32* %p) {
entry:
;CHECK: ld.global.s32 r0, [r1];
%x = load i32* %p
ret i32 %x
}
define ptx_device i32 @t2(i32* %p) {
entry:
;CHECK: ld.global.s32 r0, [r1+4];
%i = getelementptr i32* %p, i32 1
%x = load i32* %i
ret i32 %x
}
define ptx_device i32 @t3(i32* %p, i32 %q) {
entry:
;CHECK: shl.b32 r0, r2, 2;
;CHECK: ld.global.s32 r0, [r1+r0];
%i = getelementptr i32* %p, i32 %q
%x = load i32* %i
ret i32 %x
}
define ptx_device i32 @t4() {
entry:
;CHECK: ld.global.s32 r0, [array];
%i = getelementptr [10 x i32]* @array, i32 0, i32 0
%x = load i32* %i
ret i32 %x
}
define ptx_device i32 @t5() {
entry:
;CHECK: ld.global.s32 r0, [array+4];
%i = getelementptr [10 x i32]* @array, i32 0, i32 1
%x = load i32* %i
ret i32 %x
}