PTX: Attempt to cleanup/unify the handling of FP rounding modes. This requires

us to manually provide Pat<> definitions for all FP instruction patterns.

git-svn-id: https://llvm.org/svn/llvm-project/llvm/trunk@140849 91177308-0d34-0410-b5e6-96231b3b80d8
This commit is contained in:
Justin Holewinski 2011-09-30 12:54:43 +00:00
parent 10a11ecb59
commit c90e149ee4
9 changed files with 469 additions and 195 deletions

View File

@ -12,6 +12,7 @@ add_llvm_target(PTXCodeGen
PTXISelDAGToDAG.cpp
PTXISelLowering.cpp
PTXInstrInfo.cpp
PTXFPRoundingModePass.cpp
PTXFrameLowering.cpp
PTXMCAsmStreamer.cpp
PTXMCInstLower.cpp

View File

@ -19,6 +19,7 @@
#include "llvm/MC/MCInst.h"
#include "llvm/MC/MCSymbol.h"
#include "llvm/ADT/StringExtras.h"
#include "llvm/Support/ErrorHandling.h"
#include "llvm/Support/raw_ostream.h"
using namespace llvm;
@ -146,4 +147,46 @@ void PTXInstPrinter::printMemOperand(const MCInst *MI, unsigned OpNo,
printOperand(MI, OpNo+1, O);
}
void PTXInstPrinter::printRoundingMode(const MCInst *MI, unsigned OpNo,
raw_ostream &O) {
const MCOperand &Op = MI->getOperand(OpNo);
assert (Op.isImm() && "Rounding modes must be immediate values");
switch (Op.getImm()) {
default:
llvm_unreachable("Unknown rounding mode!");
case PTXRoundingMode::RndDefault:
llvm_unreachable("FP rounding-mode pass did not handle instruction!");
break;
case PTXRoundingMode::RndNone:
// Do not print anything.
break;
case PTXRoundingMode::RndNearestEven:
O << ".rn";
break;
case PTXRoundingMode::RndTowardsZero:
O << ".rz";
break;
case PTXRoundingMode::RndNegInf:
O << ".rm";
break;
case PTXRoundingMode::RndPosInf:
O << ".rp";
break;
case PTXRoundingMode::RndApprox:
O << ".approx";
break;
case PTXRoundingMode::RndNearestEvenInt:
O << ".rni";
break;
case PTXRoundingMode::RndTowardsZeroInt:
O << ".rzi";
break;
case PTXRoundingMode::RndNegInfInt:
O << ".rmi";
break;
case PTXRoundingMode::RndPosInfInt:
O << ".rpi";
break;
}
}

View File

@ -39,6 +39,7 @@ public:
void printCall(const MCInst *MI, raw_ostream &O);
void printOperand(const MCInst *MI, unsigned OpNo, raw_ostream &O);
void printMemOperand(const MCInst *MI, unsigned OpNo, raw_ostream &O);
void printRoundingMode(const MCInst *MI, unsigned OpNo, raw_ostream &O);
};
}

View File

@ -35,6 +35,26 @@ namespace llvm {
PRED_NONE = 2
};
} // namespace PTX
/// Namespace to hold all target-specific flags.
namespace PTXRoundingMode {
// Instruction Flags
enum {
// Rounding Mode Flags
RndMask = 15,
RndDefault = 0, // ---
RndNone = 1, // <NONE>
RndNearestEven = 2, // .rn
RndTowardsZero = 3, // .rz
RndNegInf = 4, // .rm
RndPosInf = 5, // .rp
RndApprox = 6, // .approx
RndNearestEvenInt = 7, // .rni
RndTowardsZeroInt = 8, // .rzi
RndNegInfInt = 9, // .rmi
RndPosInfInt = 10 // .rpi
};
} // namespace PTXII
} // namespace llvm
#endif

View File

@ -31,6 +31,9 @@ namespace llvm {
FunctionPass *createPTXMFInfoExtract(PTXTargetMachine &TM,
CodeGenOpt::Level OptLevel);
FunctionPass *createPTXFPRoundingModePass(PTXTargetMachine &TM,
CodeGenOpt::Level OptLevel);
FunctionPass *createPTXRegisterAllocator();
void LowerPTXMachineInstrToMCInst(const MachineInstr *MI, MCInst &OutMI,

View File

@ -0,0 +1,155 @@
//===-- PTXFPRoundingModePass.cpp - Assign rounding modes pass ------------===//
//
// The LLVM Compiler Infrastructure
//
// This file is distributed under the University of Illinois Open Source
// License. See LICENSE.TXT for details.
//
//===----------------------------------------------------------------------===//
//
// This file defines a machine function pass that sets appropriate FP rounding
// modes for all relevant instructions.
//
//===----------------------------------------------------------------------===//
#define DEBUG_TYPE "ptx-fp-rounding-mode"
#include "PTX.h"
#include "PTXTargetMachine.h"
#include "llvm/CodeGen/MachineFunctionPass.h"
#include "llvm/CodeGen/MachineRegisterInfo.h"
#include "llvm/Support/Debug.h"
#include "llvm/Support/ErrorHandling.h"
#include "llvm/Support/raw_ostream.h"
// NOTE: PTXFPRoundingModePass should be executed just before emission.
namespace llvm {
/// PTXFPRoundingModePass - Pass to assign appropriate FP rounding modes to
/// all FP instructions. Essentially, this pass just looks for all FP
/// instructions that have a rounding mode set to RndDefault, and sets an
/// appropriate rounding mode based on the target device.
///
class PTXFPRoundingModePass : public MachineFunctionPass {
private:
static char ID;
PTXTargetMachine& TargetMachine;
public:
PTXFPRoundingModePass(PTXTargetMachine &TM, CodeGenOpt::Level OptLevel)
: MachineFunctionPass(ID),
TargetMachine(TM) {}
virtual bool runOnMachineFunction(MachineFunction &MF);
virtual const char *getPassName() const {
return "PTX FP Rounding Mode Pass";
}
private:
void processInstruction(MachineInstr &MI);
}; // class PTXFPRoundingModePass
} // namespace llvm
using namespace llvm;
char PTXFPRoundingModePass::ID = 0;
bool PTXFPRoundingModePass::runOnMachineFunction(MachineFunction &MF) {
// Look at each basic block
for (MachineFunction::iterator bbi = MF.begin(), bbe = MF.end(); bbi != bbe;
++bbi) {
MachineBasicBlock &MBB = *bbi;
// Look at each instruction
for (MachineBasicBlock::iterator ii = MBB.begin(), ie = MBB.end();
ii != ie; ++ii) {
MachineInstr &MI = *ii;
processInstruction(MI);
}
}
return false;
}
void PTXFPRoundingModePass::processInstruction(MachineInstr &MI) {
// If the instruction has a rounding mode set to RndDefault, then assign an
// appropriate rounding mode based on the target device.
const PTXSubtarget& ST = TargetMachine.getSubtarget<PTXSubtarget>();
switch (MI.getOpcode()) {
case PTX::FADDrr32:
case PTX::FADDri32:
case PTX::FADDrr64:
case PTX::FADDri64:
case PTX::FSUBrr32:
case PTX::FSUBri32:
case PTX::FSUBrr64:
case PTX::FSUBri64:
case PTX::FMULrr32:
case PTX::FMULri32:
case PTX::FMULrr64:
case PTX::FMULri64:
if (MI.getOperand(1).getImm() == PTXRoundingMode::RndDefault) {
MI.getOperand(1).setImm(PTXRoundingMode::RndNearestEven);
}
break;
case PTX::FNEGrr32:
case PTX::FNEGri32:
case PTX::FNEGrr64:
case PTX::FNEGri64:
if (MI.getOperand(1).getImm() == PTXRoundingMode::RndDefault) {
MI.getOperand(1).setImm(PTXRoundingMode::RndNone);
}
break;
case PTX::FDIVrr32:
case PTX::FDIVri32:
case PTX::FDIVrr64:
case PTX::FDIVri64:
if (MI.getOperand(1).getImm() == PTXRoundingMode::RndDefault) {
if (ST.fdivNeedsRoundingMode())
MI.getOperand(1).setImm(PTXRoundingMode::RndNearestEven);
else
MI.getOperand(1).setImm(PTXRoundingMode::RndNone);
}
break;
case PTX::FMADrrr32:
case PTX::FMADrri32:
case PTX::FMADrii32:
case PTX::FMADrrr64:
case PTX::FMADrri64:
case PTX::FMADrii64:
if (MI.getOperand(1).getImm() == PTXRoundingMode::RndDefault) {
if (ST.fmadNeedsRoundingMode())
MI.getOperand(1).setImm(PTXRoundingMode::RndNearestEven);
else
MI.getOperand(1).setImm(PTXRoundingMode::RndNone);
}
break;
case PTX::FSQRTrr32:
case PTX::FSQRTri32:
case PTX::FSQRTrr64:
case PTX::FSQRTri64:
if (MI.getOperand(1).getImm() == PTXRoundingMode::RndDefault) {
MI.getOperand(1).setImm(PTXRoundingMode::RndNearestEven);
}
break;
case PTX::FSINrr32:
case PTX::FSINri32:
case PTX::FSINrr64:
case PTX::FSINri64:
case PTX::FCOSrr32:
case PTX::FCOSri32:
case PTX::FCOSrr64:
case PTX::FCOSri64:
if (MI.getOperand(1).getImm() == PTXRoundingMode::RndDefault) {
MI.getOperand(1).setImm(PTXRoundingMode::RndApprox);
}
break;
}
}
FunctionPass *llvm::createPTXFPRoundingModePass(PTXTargetMachine &TM,
CodeGenOpt::Level OptLevel) {
return new PTXFPRoundingModePass(TM, OptLevel);
}

View File

@ -7,12 +7,39 @@
//
//===----------------------------------------------------------------------===//
// Rounding Mode Specifier
/*class RoundingMode<bits<3> val> {
bits<3> Value = val;
}
def RndDefault : RoundingMode<0>;
def RndNearestEven : RoundingMode<1>;
def RndNearestZero : RoundingMode<2>;
def RndNegInf : RoundingMode<3>;
def RndPosInf : RoundingMode<4>;
def RndApprox : RoundingMode<5>;*/
// Rounding Mode Operand
def RndMode : Operand<i32> {
let PrintMethod = "printRoundingMode";
}
def RndDefault : PatLeaf<(i32 0)>;
// PTX Predicate operand, default to (0, 0) = (zero-reg, none).
// Leave PrintMethod empty; predicate printing is defined elsewhere.
def pred : PredicateOperand<OtherVT, (ops RegPred, i32imm),
(ops (i1 zero_reg), (i32 2))>;
def RndModeOperand : Operand<OtherVT> {
let MIOperandInfo = (ops i32imm);
}
// Instruction Types
let Namespace = "PTX" in {
class InstPTX<dag oops, dag iops, string asmstr, list<dag> pattern>
: Instruction {
dag OutOperandList = oops;

View File

@ -80,75 +80,67 @@ def PTXcopyaddress
// Instruction Class Templates
//===----------------------------------------------------------------------===//
// For floating-point instructions, we cannot just embed the pattern into the
// instruction definition since we need to muck around with the rounding mode,
// and I do not know how to insert constants into instructions directly from
// pattern matches.
//===- Floating-Point Instructions - 2 Operand Form -----------------------===//
multiclass PTX_FLOAT_2OP<string opcstr, SDNode opnode> {
multiclass PTX_FLOAT_2OP<string opcstr> {
def rr32 : InstPTX<(outs RegF32:$d),
(ins RegF32:$a),
!strconcat(opcstr, ".f32\t$d, $a"),
[(set RegF32:$d, (opnode RegF32:$a))]>;
(ins RndMode:$r, RegF32:$a),
!strconcat(opcstr, "$r.f32\t$d, $a"), []>;
def ri32 : InstPTX<(outs RegF32:$d),
(ins f32imm:$a),
!strconcat(opcstr, ".f32\t$d, $a"),
[(set RegF32:$d, (opnode fpimm:$a))]>;
(ins RndMode:$r, f32imm:$a),
!strconcat(opcstr, "$r.f32\t$d, $a"), []>;
def rr64 : InstPTX<(outs RegF64:$d),
(ins RegF64:$a),
!strconcat(opcstr, ".f64\t$d, $a"),
[(set RegF64:$d, (opnode RegF64:$a))]>;
(ins RndMode:$r, RegF64:$a),
!strconcat(opcstr, "$r.f64\t$d, $a"), []>;
def ri64 : InstPTX<(outs RegF64:$d),
(ins f64imm:$a),
!strconcat(opcstr, ".f64\t$d, $a"),
[(set RegF64:$d, (opnode fpimm:$a))]>;
(ins RndMode:$r, f64imm:$a),
!strconcat(opcstr, "$r.f64\t$d, $a"), []>;
}
//===- Floating-Point Instructions - 3 Operand Form -----------------------===//
multiclass PTX_FLOAT_3OP<string opcstr, SDNode opnode> {
multiclass PTX_FLOAT_3OP<string opcstr> {
def rr32 : InstPTX<(outs RegF32:$d),
(ins RegF32:$a, RegF32:$b),
!strconcat(opcstr, ".f32\t$d, $a, $b"),
[(set RegF32:$d, (opnode RegF32:$a, RegF32:$b))]>;
(ins RndMode:$r, RegF32:$a, RegF32:$b),
!strconcat(opcstr, "$r.f32\t$d, $a, $b"), []>;
def ri32 : InstPTX<(outs RegF32:$d),
(ins RegF32:$a, f32imm:$b),
!strconcat(opcstr, ".f32\t$d, $a, $b"),
[(set RegF32:$d, (opnode RegF32:$a, fpimm:$b))]>;
(ins RndMode:$r, RegF32:$a, f32imm:$b),
!strconcat(opcstr, "$r.f32\t$d, $a, $b"), []>;
def rr64 : InstPTX<(outs RegF64:$d),
(ins RegF64:$a, RegF64:$b),
!strconcat(opcstr, ".f64\t$d, $a, $b"),
[(set RegF64:$d, (opnode RegF64:$a, RegF64:$b))]>;
(ins RndMode:$r, RegF64:$a, RegF64:$b),
!strconcat(opcstr, "$r.f64\t$d, $a, $b"), []>;
def ri64 : InstPTX<(outs RegF64:$d),
(ins RegF64:$a, f64imm:$b),
!strconcat(opcstr, ".f64\t$d, $a, $b"),
[(set RegF64:$d, (opnode RegF64:$a, fpimm:$b))]>;
(ins RndMode:$r, RegF64:$a, f64imm:$b),
!strconcat(opcstr, "$r.f64\t$d, $a, $b"), []>;
}
//===- Floating-Point Instructions - 4 Operand Form -----------------------===//
multiclass PTX_FLOAT_4OP<string opcstr, SDNode opnode1, SDNode opnode2> {
multiclass PTX_FLOAT_4OP<string opcstr> {
def rrr32 : InstPTX<(outs RegF32:$d),
(ins RegF32:$a, RegF32:$b, RegF32:$c),
!strconcat(opcstr, ".f32\t$d, $a, $b, $c"),
[(set RegF32:$d, (opnode2 (opnode1 RegF32:$a,
RegF32:$b),
RegF32:$c))]>;
(ins RndMode:$r, RegF32:$a, RegF32:$b, RegF32:$c),
!strconcat(opcstr, "$r.f32\t$d, $a, $b, $c"), []>;
def rri32 : InstPTX<(outs RegF32:$d),
(ins RegF32:$a, RegF32:$b, f32imm:$c),
!strconcat(opcstr, ".f32\t$d, $a, $b, $c"),
[(set RegF32:$d, (opnode2 (opnode1 RegF32:$a,
RegF32:$b),
fpimm:$c))]>;
(ins RndMode:$r, RegF32:$a, RegF32:$b, f32imm:$c),
!strconcat(opcstr, "$r.f32\t$d, $a, $b, $c"), []>;
def rii32 : InstPTX<(outs RegF32:$d),
(ins RndMode:$r, RegF32:$a, f32imm:$b, f32imm:$c),
!strconcat(opcstr, "$r.f32\t$d, $a, $b, $c"), []>;
def rrr64 : InstPTX<(outs RegF64:$d),
(ins RegF64:$a, RegF64:$b, RegF64:$c),
!strconcat(opcstr, ".f64\t$d, $a, $b, $c"),
[(set RegF64:$d, (opnode2 (opnode1 RegF64:$a,
RegF64:$b),
RegF64:$c))]>;
(ins RndMode:$r, RegF64:$a, RegF64:$b, RegF64:$c),
!strconcat(opcstr, "$r.f64\t$d, $a, $b, $c"), []>;
def rri64 : InstPTX<(outs RegF64:$d),
(ins RegF64:$a, RegF64:$b, f64imm:$c),
!strconcat(opcstr, ".f64\t$d, $a, $b, $c"),
[(set RegF64:$d, (opnode2 (opnode1 RegF64:$a,
RegF64:$b),
fpimm:$c))]>;
(ins RndMode:$r, RegF64:$a, RegF64:$b, f64imm:$c),
!strconcat(opcstr, "$r.f64\t$d, $a, $b, $c"), []>;
def rii64 : InstPTX<(outs RegF64:$d),
(ins RndMode:$r, RegF64:$a, f64imm:$b, f64imm:$c),
!strconcat(opcstr, "$r.f64\t$d, $a, $b, $c"), []>;
}
multiclass INT3<string opcstr, SDNode opnode> {
//===- Integer Instructions - 3 Operand Form ------------------------------===//
multiclass PTX_INT3<string opcstr, SDNode opnode> {
def rr16 : InstPTX<(outs RegI16:$d),
(ins RegI16:$a, RegI16:$b),
!strconcat(opcstr, ".u16\t$d, $a, $b"),
@ -175,6 +167,7 @@ multiclass INT3<string opcstr, SDNode opnode> {
[(set RegI64:$d, (opnode RegI64:$a, imm:$b))]>;
}
//===- Bitwise Logic Instructions - 3 Operand Form ------------------------===//
multiclass PTX_LOGIC<string opcstr, SDNode opnode> {
def ripreds : InstPTX<(outs RegPred:$d),
(ins RegPred:$a, i1imm:$b),
@ -210,7 +203,8 @@ multiclass PTX_LOGIC<string opcstr, SDNode opnode> {
[(set RegI64:$d, (opnode RegI64:$a, imm:$b))]>;
}
multiclass INT3ntnc<string opcstr, SDNode opnode> {
//===- Integer Shift Instructions - 3 Operand Form ------------------------===//
multiclass PTX_INT3ntnc<string opcstr, SDNode opnode> {
def rr16 : InstPTX<(outs RegI16:$d),
(ins RegI16:$a, RegI16:$b),
!strconcat(opcstr, "16\t$d, $a, $b"),
@ -249,6 +243,7 @@ multiclass INT3ntnc<string opcstr, SDNode opnode> {
[(set RegI64:$d, (opnode imm:$a, RegI64:$b))]>;
}
//===- Set Predicate Instructions (Int) - 3/4 Operand Forms ---------------===//
multiclass PTX_SETP_I<RegisterClass RC, string regclsname, Operand immcls,
CondCode cmp, string cmpstr> {
// TODO support 5-operand format: p|q, a, b, c
@ -333,6 +328,7 @@ multiclass PTX_SETP_I<RegisterClass RC, string regclsname, Operand immcls,
(not RegPred:$c)))]>;
}
//===- Set Predicate Instructions (FP) - 3/4 Operand Form -----------------===//
multiclass PTX_SETP_FP<RegisterClass RC, string regclsname, Operand immcls,
CondCode ucmp, CondCode ocmp, string cmpstr> {
// TODO support 5-operand format: p|q, a, b, c
@ -432,6 +428,7 @@ multiclass PTX_SETP_FP<RegisterClass RC, string regclsname, Operand immcls,
(not RegPred:$c)))]>;
}
//===- Select Predicate Instructions - 4 Operand Form ---------------------===//
multiclass PTX_SELP<RegisterClass RC, string regclsname, Operand immcls,
SDNode immnode> {
def rr
@ -456,118 +453,60 @@ multiclass PTX_SELP<RegisterClass RC, string regclsname, Operand immcls,
///===- Integer Arithmetic Instructions -----------------------------------===//
defm ADD : INT3<"add", add>;
defm SUB : INT3<"sub", sub>;
defm MUL : INT3<"mul.lo", mul>; // FIXME: Allow 32x32 -> 64 multiplies
defm DIV : INT3<"div", udiv>;
defm REM : INT3<"rem", urem>;
defm ADD : PTX_INT3<"add", add>;
defm SUB : PTX_INT3<"sub", sub>;
defm MUL : PTX_INT3<"mul.lo", mul>; // FIXME: Allow 32x32 -> 64 multiplies
defm DIV : PTX_INT3<"div", udiv>;
defm REM : PTX_INT3<"rem", urem>;
///===- Floating-Point Arithmetic Instructions ----------------------------===//
// Standard Unary Operations
defm FNEG : PTX_FLOAT_2OP<"neg", fneg>;
// FNEG
defm FNEG : PTX_FLOAT_2OP<"neg">;
// Standard Binary Operations
defm FADD : PTX_FLOAT_3OP<"add.rn", fadd>;
defm FSUB : PTX_FLOAT_3OP<"sub.rn", fsub>;
defm FMUL : PTX_FLOAT_3OP<"mul.rn", fmul>;
// For floating-point division:
// SM_13+ defaults to .rn for f32 and f64,
// SM10 must *not* provide a rounding
// TODO:
// - Allow user selection of rounding modes for fdiv
// - Add support for -prec-div=false (.approx)
def FDIVrr32SM13 : InstPTX<(outs RegF32:$d),
(ins RegF32:$a, RegF32:$b),
"div.rn.f32\t$d, $a, $b",
[(set RegF32:$d, (fdiv RegF32:$a, RegF32:$b))]>,
Requires<[FDivNeedsRoundingMode]>;
def FDIVri32SM13 : InstPTX<(outs RegF32:$d),
(ins RegF32:$a, f32imm:$b),
"div.rn.f32\t$d, $a, $b",
[(set RegF32:$d, (fdiv RegF32:$a, fpimm:$b))]>,
Requires<[FDivNeedsRoundingMode]>;
def FDIVrr32SM10 : InstPTX<(outs RegF32:$d),
(ins RegF32:$a, RegF32:$b),
"div.f32\t$d, $a, $b",
[(set RegF32:$d, (fdiv RegF32:$a, RegF32:$b))]>,
Requires<[FDivNoRoundingMode]>;
def FDIVri32SM10 : InstPTX<(outs RegF32:$d),
(ins RegF32:$a, f32imm:$b),
"div.f32\t$d, $a, $b",
[(set RegF32:$d, (fdiv RegF32:$a, fpimm:$b))]>,
Requires<[FDivNoRoundingMode]>;
def FDIVrr64SM13 : InstPTX<(outs RegF64:$d),
(ins RegF64:$a, RegF64:$b),
"div.rn.f64\t$d, $a, $b",
[(set RegF64:$d, (fdiv RegF64:$a, RegF64:$b))]>,
Requires<[FDivNeedsRoundingMode]>;
def FDIVri64SM13 : InstPTX<(outs RegF64:$d),
(ins RegF64:$a, f64imm:$b),
"div.rn.f64\t$d, $a, $b",
[(set RegF64:$d, (fdiv RegF64:$a, fpimm:$b))]>,
Requires<[FDivNeedsRoundingMode]>;
def FDIVrr64SM10 : InstPTX<(outs RegF64:$d),
(ins RegF64:$a, RegF64:$b),
"div.f64\t$d, $a, $b",
[(set RegF64:$d, (fdiv RegF64:$a, RegF64:$b))]>,
Requires<[FDivNoRoundingMode]>;
def FDIVri64SM10 : InstPTX<(outs RegF64:$d),
(ins RegF64:$a, f64imm:$b),
"div.f64\t$d, $a, $b",
[(set RegF64:$d, (fdiv RegF64:$a, fpimm:$b))]>,
Requires<[FDivNoRoundingMode]>;
defm FADD : PTX_FLOAT_3OP<"add">;
defm FSUB : PTX_FLOAT_3OP<"sub">;
defm FMUL : PTX_FLOAT_3OP<"mul">;
defm FDIV : PTX_FLOAT_3OP<"div">;
// Multi-operation hybrid instructions
defm FMAD : PTX_FLOAT_4OP<"mad">, Requires<[SupportsFMA]>;
// The selection of mad/fma is tricky. In some cases, they are the *same*
// instruction, but in other cases we may prefer one or the other. Also,
// different PTX versions differ on whether rounding mode flags are required.
// In the short term, mad is supported on all PTX versions and we use a
// default rounding mode no matter what shader model or PTX version.
// TODO: Allow the rounding mode to be selectable through llc.
defm FMADSM13 : PTX_FLOAT_4OP<"mad.rn", fmul, fadd>,
Requires<[FMadNeedsRoundingMode, SupportsFMA]>;
defm FMAD : PTX_FLOAT_4OP<"mad", fmul, fadd>,
Requires<[FMadNoRoundingMode, SupportsFMA]>;
///===- Floating-Point Intrinsic Instructions -----------------------------===//
def FSQRT32 : InstPTX<(outs RegF32:$d),
(ins RegF32:$a),
"sqrt.rn.f32\t$d, $a",
[(set RegF32:$d, (fsqrt RegF32:$a))]>;
// SQRT
def FSQRTrr32 : InstPTX<(outs RegF32:$d), (ins RndMode:$r, RegF32:$a),
"sqrt$r.f32\t$d, $a", []>;
def FSQRTri32 : InstPTX<(outs RegF32:$d), (ins RndMode:$r, f32imm:$a),
"sqrt$r.f32\t$d, $a", []>;
def FSQRTrr64 : InstPTX<(outs RegF64:$d), (ins RndMode:$r, RegF64:$a),
"sqrt$r.f64\t$d, $a", []>;
def FSQRTri64 : InstPTX<(outs RegF64:$d), (ins RndMode:$r, f64imm:$a),
"sqrt$r.f64\t$d, $a", []>;
def FSQRT64 : InstPTX<(outs RegF64:$d),
(ins RegF64:$a),
"sqrt.rn.f64\t$d, $a",
[(set RegF64:$d, (fsqrt RegF64:$a))]>;
// SIN
def FSINrr32 : InstPTX<(outs RegF32:$d), (ins RndMode:$r, RegF32:$a),
"sin$r.f32\t$d, $a", []>;
def FSINri32 : InstPTX<(outs RegF32:$d), (ins RndMode:$r, f32imm:$a),
"sin$r.f32\t$d, $a", []>;
def FSINrr64 : InstPTX<(outs RegF64:$d), (ins RndMode:$r, RegF64:$a),
"sin$r.f64\t$d, $a", []>;
def FSINri64 : InstPTX<(outs RegF64:$d), (ins RndMode:$r, f64imm:$a),
"sin$r.f64\t$d, $a", []>;
def FSIN32 : InstPTX<(outs RegF32:$d),
(ins RegF32:$a),
"sin.approx.f32\t$d, $a",
[(set RegF32:$d, (fsin RegF32:$a))]>;
// COS
def FCOSrr32 : InstPTX<(outs RegF32:$d), (ins RndMode:$r, RegF32:$a),
"cos$r.f32\t$d, $a", []>;
def FCOSri32 : InstPTX<(outs RegF32:$d), (ins RndMode:$r, f32imm:$a),
"cos$r.f32\t$d, $a", []>;
def FCOSrr64 : InstPTX<(outs RegF64:$d), (ins RndMode:$r, RegF64:$a),
"cos$r.f64\t$d, $a", []>;
def FCOSri64 : InstPTX<(outs RegF64:$d), (ins RndMode:$r, f64imm:$a),
"cos$r.f64\t$d, $a", []>;
def FSIN64 : InstPTX<(outs RegF64:$d),
(ins RegF64:$a),
"sin.approx.f64\t$d, $a",
[(set RegF64:$d, (fsin RegF64:$a))]>;
def FCOS32 : InstPTX<(outs RegF32:$d),
(ins RegF32:$a),
"cos.approx.f32\t$d, $a",
[(set RegF32:$d, (fcos RegF32:$a))]>;
def FCOS64 : InstPTX<(outs RegF64:$d),
(ins RegF64:$a),
"cos.approx.f64\t$d, $a",
[(set RegF64:$d, (fcos RegF64:$a))]>;
///===- Comparison and Selection Instructions -----------------------------===//
@ -641,9 +580,9 @@ defm SELPf64 : PTX_SELP<RegF64, "f64", f64imm, fpimm>;
///===- Logic and Shift Instructions --------------------------------------===//
defm SHL : INT3ntnc<"shl.b", PTXshl>;
defm SRL : INT3ntnc<"shr.u", PTXsrl>;
defm SRA : INT3ntnc<"shr.s", PTXsra>;
defm SHL : PTX_INT3ntnc<"shl.b", PTXshl>;
defm SRL : PTX_INT3ntnc<"shr.u", PTXsrl>;
defm SRA : PTX_INT3ntnc<"shr.s", PTXsra>;
defm AND : PTX_LOGIC<"and", and>;
defm OR : PTX_LOGIC<"or", or>;
@ -798,6 +737,136 @@ def CVTf64s64
def CVTf64f32
: InstPTX<(outs RegF64:$d), (ins RegF32:$a), "cvt.f64.f32\t$d, $a", []>;
///===- Control Flow Instructions -----------------------------------------===//
let isBranch = 1, isTerminator = 1, isBarrier = 1 in {
def BRAd
: InstPTX<(outs), (ins brtarget:$d), "bra\t$d", [(br bb:$d)]>;
}
let isBranch = 1, isTerminator = 1 in {
// FIXME: The pattern part is blank because I cannot (or do not yet know
// how to) use the first operand of PredicateOperand (a RegPred register) here
def BRAdp
: InstPTX<(outs), (ins brtarget:$d), "bra\t$d",
[/*(brcond pred:$_p, bb:$d)*/]>;
}
let isReturn = 1, isTerminator = 1, isBarrier = 1 in {
def EXIT : InstPTX<(outs), (ins), "exit", [(PTXexit)]>;
def RET : InstPTX<(outs), (ins), "ret", [(PTXret)]>;
}
let hasSideEffects = 1 in {
def CALL : InstPTX<(outs), (ins), "call", [(PTXcall)]>;
}
///===- Parameter Passing Pseudo-Instructions -----------------------------===//
def READPARAMPRED : InstPTX<(outs RegPred:$a), (ins i32imm:$b),
"mov.pred\t$a, %param$b", []>;
def READPARAMI16 : InstPTX<(outs RegI16:$a), (ins i32imm:$b),
"mov.b16\t$a, %param$b", []>;
def READPARAMI32 : InstPTX<(outs RegI32:$a), (ins i32imm:$b),
"mov.b32\t$a, %param$b", []>;
def READPARAMI64 : InstPTX<(outs RegI64:$a), (ins i32imm:$b),
"mov.b64\t$a, %param$b", []>;
def READPARAMF32 : InstPTX<(outs RegF32:$a), (ins i32imm:$b),
"mov.f32\t$a, %param$b", []>;
def READPARAMF64 : InstPTX<(outs RegF64:$a), (ins i32imm:$b),
"mov.f64\t$a, %param$b", []>;
def WRITEPARAMPRED : InstPTX<(outs), (ins RegPred:$a), "//w", []>;
def WRITEPARAMI16 : InstPTX<(outs), (ins RegI16:$a), "//w", []>;
def WRITEPARAMI32 : InstPTX<(outs), (ins RegI32:$a), "//w", []>;
def WRITEPARAMI64 : InstPTX<(outs), (ins RegI64:$a), "//w", []>;
def WRITEPARAMF32 : InstPTX<(outs), (ins RegF32:$a), "//w", []>;
def WRITEPARAMF64 : InstPTX<(outs), (ins RegF64:$a), "//w", []>;
//===----------------------------------------------------------------------===//
// Instruction Selection Patterns
//===----------------------------------------------------------------------===//
// FADD
def : Pat<(f32 (fadd RegF32:$a, RegF32:$b)),
(FADDrr32 RndDefault, RegF32:$a, RegF32:$b)>;
def : Pat<(f32 (fadd RegF32:$a, fpimm:$b)),
(FADDri32 RndDefault, RegF32:$a, fpimm:$b)>;
def : Pat<(f64 (fadd RegF64:$a, RegF64:$b)),
(FADDrr64 RndDefault, RegF64:$a, RegF64:$b)>;
def : Pat<(f64 (fadd RegF64:$a, fpimm:$b)),
(FADDri64 RndDefault, RegF64:$a, fpimm:$b)>;
// FSUB
def : Pat<(f32 (fsub RegF32:$a, RegF32:$b)),
(FSUBrr32 RndDefault, RegF32:$a, RegF32:$b)>;
def : Pat<(f32 (fsub RegF32:$a, fpimm:$b)),
(FSUBri32 RndDefault, RegF32:$a, fpimm:$b)>;
def : Pat<(f64 (fsub RegF64:$a, RegF64:$b)),
(FSUBrr64 RndDefault, RegF64:$a, RegF64:$b)>;
def : Pat<(f64 (fsub RegF64:$a, fpimm:$b)),
(FSUBri64 RndDefault, RegF64:$a, fpimm:$b)>;
// FMUL
def : Pat<(f32 (fmul RegF32:$a, RegF32:$b)),
(FMULrr32 RndDefault, RegF32:$a, RegF32:$b)>;
def : Pat<(f32 (fmul RegF32:$a, fpimm:$b)),
(FMULri32 RndDefault, RegF32:$a, fpimm:$b)>;
def : Pat<(f64 (fmul RegF64:$a, RegF64:$b)),
(FMULrr64 RndDefault, RegF64:$a, RegF64:$b)>;
def : Pat<(f64 (fmul RegF64:$a, fpimm:$b)),
(FMULri64 RndDefault, RegF64:$a, fpimm:$b)>;
// FDIV
def : Pat<(f32 (fdiv RegF32:$a, RegF32:$b)),
(FDIVrr32 RndDefault, RegF32:$a, RegF32:$b)>;
def : Pat<(f32 (fdiv RegF32:$a, fpimm:$b)),
(FDIVri32 RndDefault, RegF32:$a, fpimm:$b)>;
def : Pat<(f64 (fdiv RegF64:$a, RegF64:$b)),
(FDIVrr64 RndDefault, RegF64:$a, RegF64:$b)>;
def : Pat<(f64 (fdiv RegF64:$a, fpimm:$b)),
(FDIVri64 RndDefault, RegF64:$a, fpimm:$b)>;
// FMUL+FADD
def : Pat<(f32 (fadd (fmul RegF32:$a, RegF32:$b), RegF32:$c)),
(FMADrrr32 RndDefault, RegF32:$a, RegF32:$b, RegF32:$c)>;
def : Pat<(f32 (fadd (fmul RegF32:$a, RegF32:$b), fpimm:$c)),
(FMADrri32 RndDefault, RegF32:$a, RegF32:$b, fpimm:$c)>;
def : Pat<(f32 (fadd (fmul RegF32:$a, fpimm:$b), fpimm:$c)),
(FMADrrr32 RndDefault, RegF32:$a, fpimm:$b, fpimm:$c)>;
def : Pat<(f32 (fadd (fmul RegF32:$a, RegF32:$b), fpimm:$c)),
(FMADrri32 RndDefault, RegF32:$a, RegF32:$b, fpimm:$c)>;
def : Pat<(f64 (fadd (fmul RegF64:$a, RegF64:$b), RegF64:$c)),
(FMADrrr64 RndDefault, RegF64:$a, RegF64:$b, RegF64:$c)>;
def : Pat<(f64 (fadd (fmul RegF64:$a, RegF64:$b), fpimm:$c)),
(FMADrri64 RndDefault, RegF64:$a, RegF64:$b, fpimm:$c)>;
def : Pat<(f64 (fadd (fmul RegF64:$a, fpimm:$b), fpimm:$c)),
(FMADrri64 RndDefault, RegF64:$a, fpimm:$b, fpimm:$c)>;
// FNEG
def : Pat<(f32 (fneg RegF32:$a)), (FNEGrr32 RndDefault, RegF32:$a)>;
def : Pat<(f32 (fneg fpimm:$a)), (FNEGri32 RndDefault, fpimm:$a)>;
def : Pat<(f64 (fneg RegF64:$a)), (FNEGrr64 RndDefault, RegF64:$a)>;
def : Pat<(f64 (fneg fpimm:$a)), (FNEGri64 RndDefault, fpimm:$a)>;
// FSQRT
def : Pat<(f32 (fsqrt RegF32:$a)), (FSQRTrr32 RndDefault, RegF32:$a)>;
def : Pat<(f32 (fsqrt fpimm:$a)), (FSQRTri32 RndDefault, fpimm:$a)>;
def : Pat<(f64 (fsqrt RegF64:$a)), (FSQRTrr64 RndDefault, RegF64:$a)>;
def : Pat<(f64 (fsqrt fpimm:$a)), (FSQRTri64 RndDefault, fpimm:$a)>;
// FSIN
def : Pat<(f32 (fsin RegF32:$a)), (FSINrr32 RndDefault, RegF32:$a)>;
def : Pat<(f32 (fsin fpimm:$a)), (FSINri32 RndDefault, fpimm:$a)>;
def : Pat<(f64 (fsin RegF64:$a)), (FSINrr64 RndDefault, RegF64:$a)>;
def : Pat<(f64 (fsin fpimm:$a)), (FSINri64 RndDefault, fpimm:$a)>;
// FCOS
def : Pat<(f32 (fcos RegF32:$a)), (FCOSrr32 RndDefault, RegF32:$a)>;
def : Pat<(f32 (fcos fpimm:$a)), (FCOSri32 RndDefault, fpimm:$a)>;
def : Pat<(f64 (fcos RegF64:$a)), (FCOSrr64 RndDefault, RegF64:$a)>;
def : Pat<(f64 (fcos fpimm:$a)), (FCOSri64 RndDefault, fpimm:$a)>;
// Type conversion notes:
// - PTX does not directly support converting a predicate to a value, so we
@ -881,52 +950,6 @@ def : Pat<(f64 (fextend RegF32:$a)), (CVTf64f32 RegF32:$a)>;
def : Pat<(f64 (bitconvert RegI64:$a)), (MOVf64i64 RegI64:$a)>;
///===- Control Flow Instructions -----------------------------------------===//
let isBranch = 1, isTerminator = 1, isBarrier = 1 in {
def BRAd
: InstPTX<(outs), (ins brtarget:$d), "bra\t$d", [(br bb:$d)]>;
}
let isBranch = 1, isTerminator = 1 in {
// FIXME: The pattern part is blank because I cannot (or do not yet know
// how to) use the first operand of PredicateOperand (a RegPred register) here
def BRAdp
: InstPTX<(outs), (ins brtarget:$d), "bra\t$d",
[/*(brcond pred:$_p, bb:$d)*/]>;
}
let isReturn = 1, isTerminator = 1, isBarrier = 1 in {
def EXIT : InstPTX<(outs), (ins), "exit", [(PTXexit)]>;
def RET : InstPTX<(outs), (ins), "ret", [(PTXret)]>;
}
let hasSideEffects = 1 in {
def CALL : InstPTX<(outs), (ins), "call", [(PTXcall)]>;
}
///===- Parameter Passing Pseudo-Instructions -----------------------------===//
def READPARAMPRED : InstPTX<(outs RegPred:$a), (ins i32imm:$b),
"mov.pred\t$a, %param$b", []>;
def READPARAMI16 : InstPTX<(outs RegI16:$a), (ins i32imm:$b),
"mov.b16\t$a, %param$b", []>;
def READPARAMI32 : InstPTX<(outs RegI32:$a), (ins i32imm:$b),
"mov.b32\t$a, %param$b", []>;
def READPARAMI64 : InstPTX<(outs RegI64:$a), (ins i32imm:$b),
"mov.b64\t$a, %param$b", []>;
def READPARAMF32 : InstPTX<(outs RegF32:$a), (ins i32imm:$b),
"mov.f32\t$a, %param$b", []>;
def READPARAMF64 : InstPTX<(outs RegF64:$a), (ins i32imm:$b),
"mov.f64\t$a, %param$b", []>;
def WRITEPARAMPRED : InstPTX<(outs), (ins RegPred:$a), "//w", []>;
def WRITEPARAMI16 : InstPTX<(outs), (ins RegI16:$a), "//w", []>;
def WRITEPARAMI32 : InstPTX<(outs), (ins RegI32:$a), "//w", []>;
def WRITEPARAMI64 : InstPTX<(outs), (ins RegI64:$a), "//w", []>;
def WRITEPARAMF32 : InstPTX<(outs), (ins RegF32:$a), "//w", []>;
def WRITEPARAMF64 : InstPTX<(outs), (ins RegF64:$a), "//w", []>;
///===- Intrinsic Instructions --------------------------------------------===//
include "PTXIntrinsicInstrInfo.td"

View File

@ -367,6 +367,7 @@ bool PTXTargetMachine::addCommonCodeGenPasses(PassManagerBase &PM,
printNoVerify(PM, "After PreEmit passes");
PM.add(createPTXMFInfoExtract(*this, OptLevel));
PM.add(createPTXFPRoundingModePass(*this, OptLevel));
return false;
}