Masked gather and scatter: Added code for SelectionDAG.

All other patches, including tests will follow.

http://reviews.llvm.org/D7665



git-svn-id: https://llvm.org/svn/llvm-project/llvm/trunk@235970 91177308-0d34-0410-b5e6-96231b3b80d8
This commit is contained in:
Elena Demikhovsky 2015-04-28 07:57:37 +00:00
parent 974d5d32c8
commit 8dfda019a0
6 changed files with 294 additions and 2 deletions

View File

@ -687,9 +687,16 @@ namespace ISD {
ATOMIC_LOAD_UMIN,
ATOMIC_LOAD_UMAX,
// Masked load and store
// Masked load and store - consecutive vector load and store operations
// with additional mask operand that prevents memory accesses to the
// masked-off lanes.
MLOAD, MSTORE,
// Masked gather and scatter - load and store operations for a vector of
// random addresses with additional mask operand that prevents memory
// accesses to the masked-off lanes.
MGATHER, MSCATTER,
/// This corresponds to the llvm.lifetime.* intrinsics. The first operand
/// is the chain and the second operand is the alloca pointer.
LIFETIME_START, LIFETIME_END,

View File

@ -856,6 +856,10 @@ public:
SDValue getMaskedStore(SDValue Chain, SDLoc dl, SDValue Val,
SDValue Ptr, SDValue Mask, EVT MemVT,
MachineMemOperand *MMO, bool IsTrunc);
SDValue getMaskedGather(SDVTList VTs, EVT VT, SDLoc dl,
ArrayRef<SDValue> Ops, MachineMemOperand *MMO);
SDValue getMaskedScatter(SDVTList VTs, EVT VT, SDLoc dl,
ArrayRef<SDValue> Ops, MachineMemOperand *MMO);
/// Construct a node to track a Value* through the backend.
SDValue getSrcValue(const Value *v);

View File

@ -1151,6 +1151,8 @@ public:
N->getOpcode() == ISD::ATOMIC_STORE ||
N->getOpcode() == ISD::MLOAD ||
N->getOpcode() == ISD::MSTORE ||
N->getOpcode() == ISD::MGATHER ||
N->getOpcode() == ISD::MSCATTER ||
N->isMemIntrinsic() ||
N->isTargetMemoryOpcode();
}
@ -1987,6 +1989,82 @@ public:
}
};
/// This is a base class is used to represent
/// MGATHER and MSCATTER nodes
///
class MaskedGatherScatterSDNode : public MemSDNode {
// Operands
SDUse Ops[5];
public:
friend class SelectionDAG;
MaskedGatherScatterSDNode(ISD::NodeType NodeTy, unsigned Order, DebugLoc dl,
ArrayRef<SDValue> Operands, SDVTList VTs, EVT MemVT,
MachineMemOperand *MMO)
: MemSDNode(NodeTy, Order, dl, VTs, MemVT, MMO) {
assert(Operands.size() == 5 && "Incompatible number of operands");
InitOperands(Ops, Operands.data(), Operands.size());
}
// In the both nodes address is Op1, mask is Op2:
// MaskedGatherSDNode (Chain, src0, mask, base, index), src0 is a passthru value
// MaskedScatterSDNode (Chain, value, mask, base, index)
// Mask is a vector of i1 elements
const SDValue &getBasePtr() const { return getOperand(3); }
const SDValue &getIndex() const { return getOperand(4); }
const SDValue &getMask() const { return getOperand(2); }
const SDValue &getValue() const { return getOperand(1); }
static bool classof(const SDNode *N) {
return N->getOpcode() == ISD::MGATHER ||
N->getOpcode() == ISD::MSCATTER;
}
};
/// This class is used to represent an MGATHER node
///
class MaskedGatherSDNode : public MaskedGatherScatterSDNode {
public:
friend class SelectionDAG;
MaskedGatherSDNode(unsigned Order, DebugLoc dl, ArrayRef<SDValue> Operands,
SDVTList VTs, EVT MemVT, MachineMemOperand *MMO)
: MaskedGatherScatterSDNode(ISD::MGATHER, Order, dl, Operands, VTs, MemVT,
MMO) {
assert(getValue().getValueType() == getValueType(0) &&
"Incompatible type of the PathThru value in MaskedGatherSDNode");
assert(getMask().getValueType().getVectorNumElements() ==
getValueType(0).getVectorNumElements() &&
"Vector width mismatch between mask and data");
assert(getMask().getValueType().getScalarType() == MVT::i1 &&
"Vector width mismatch between mask and data");
}
static bool classof(const SDNode *N) {
return N->getOpcode() == ISD::MGATHER;
}
};
/// This class is used to represent an MSCATTER node
///
class MaskedScatterSDNode : public MaskedGatherScatterSDNode {
public:
friend class SelectionDAG;
MaskedScatterSDNode(unsigned Order, DebugLoc dl,ArrayRef<SDValue> Operands,
SDVTList VTs, EVT MemVT, MachineMemOperand *MMO)
: MaskedGatherScatterSDNode(ISD::MSCATTER, Order, dl, Operands, VTs, MemVT,
MMO) {
assert(getMask().getValueType().getVectorNumElements() ==
getValue().getValueType().getVectorNumElements() &&
"Vector width mismatch between mask and data");
assert(getMask().getValueType().getScalarType() == MVT::i1 &&
"Vector width mismatch between mask and data");
}
static bool classof(const SDNode *N) {
return N->getOpcode() == ISD::MSCATTER;
}
};
/// An SDNode that represents everything that will be needed
/// to construct a MachineInstr. These nodes are created during the
/// instruction selection proper phase.
@ -2078,7 +2156,7 @@ template <> struct GraphTraits<SDNode*> {
};
/// The largest SDNode class.
typedef AtomicSDNode LargestSDNode;
typedef MaskedGatherScatterSDNode LargestSDNode;
/// The SDNode class with the greatest alignment requirement.
typedef GlobalAddressSDNode MostAlignedSDNode;

View File

@ -5097,6 +5097,55 @@ SDValue SelectionDAG::getMaskedStore(SDValue Chain, SDLoc dl, SDValue Val,
return SDValue(N, 0);
}
SDValue
SelectionDAG::getMaskedGather(SDVTList VTs, EVT VT, SDLoc dl,
ArrayRef<SDValue> Ops,
MachineMemOperand *MMO) {
FoldingSetNodeID ID;
AddNodeIDNode(ID, ISD::MGATHER, VTs, Ops);
ID.AddInteger(VT.getRawBits());
ID.AddInteger(encodeMemSDNodeFlags(ISD::NON_EXTLOAD, ISD::UNINDEXED,
MMO->isVolatile(),
MMO->isNonTemporal(),
MMO->isInvariant()));
ID.AddInteger(MMO->getPointerInfo().getAddrSpace());
void *IP = nullptr;
if (SDNode *E = CSEMap.FindNodeOrInsertPos(ID, IP)) {
cast<MaskedGatherSDNode>(E)->refineAlignment(MMO);
return SDValue(E, 0);
}
MaskedGatherSDNode *N =
new (NodeAllocator) MaskedGatherSDNode(dl.getIROrder(), dl.getDebugLoc(),
Ops, VTs, VT, MMO);
CSEMap.InsertNode(N, IP);
InsertNode(N);
return SDValue(N, 0);
}
SDValue SelectionDAG::getMaskedScatter(SDVTList VTs, EVT VT, SDLoc dl,
ArrayRef<SDValue> Ops,
MachineMemOperand *MMO) {
FoldingSetNodeID ID;
AddNodeIDNode(ID, ISD::MSCATTER, VTs, Ops);
ID.AddInteger(VT.getRawBits());
ID.AddInteger(encodeMemSDNodeFlags(false, ISD::UNINDEXED, MMO->isVolatile(),
MMO->isNonTemporal(),
MMO->isInvariant()));
ID.AddInteger(MMO->getPointerInfo().getAddrSpace());
void *IP = nullptr;
if (SDNode *E = CSEMap.FindNodeOrInsertPos(ID, IP)) {
cast<MaskedScatterSDNode>(E)->refineAlignment(MMO);
return SDValue(E, 0);
}
SDNode *N =
new (NodeAllocator) MaskedScatterSDNode(dl.getIROrder(), dl.getDebugLoc(),
Ops, VTs, VT, MMO);
CSEMap.InsertNode(N, IP);
InsertNode(N);
return SDValue(N, 0);
}
SDValue SelectionDAG::getVAArg(EVT VT, SDLoc dl,
SDValue Chain, SDValue Ptr,
SDValue SV,

View File

@ -1059,6 +1059,12 @@ SDValue SelectionDAGBuilder::getValue(const Value *V) {
return Val;
}
// Return true if SDValue exists for the given Value
bool SelectionDAGBuilder::findValue(const Value *V) const {
return (NodeMap.find(V) != NodeMap.end()) ||
(FuncInfo.ValueMap.find(V) != FuncInfo.ValueMap.end());
}
/// getNonRegisterValue - Return an SDValue for the given Value, but
/// don't look in FuncInfo.ValueMap for a virtual register.
SDValue SelectionDAGBuilder::getNonRegisterValue(const Value *V) {
@ -3026,6 +3032,92 @@ void SelectionDAGBuilder::visitMaskedStore(const CallInst &I) {
setValue(&I, StoreNode);
}
// Gather/scatter receive a vector of pointers.
// This vector of pointers may be represented as a base pointer + vector of
// indices, it depends on GEP and instruction preceeding GEP
// that calculates indices
static bool getUniformBase(Value *& Ptr, SDValue& Base, SDValue& Index,
SelectionDAGBuilder* SDB) {
assert (Ptr->getType()->isVectorTy() && "Uexpected pointer type");
GetElementPtrInst *Gep = dyn_cast<GetElementPtrInst>(Ptr);
if (!Gep || Gep->getNumOperands() > 2)
return false;
ShuffleVectorInst *ShuffleInst =
dyn_cast<ShuffleVectorInst>(Gep->getPointerOperand());
if (!ShuffleInst || !ShuffleInst->getMask()->isNullValue() ||
cast<Instruction>(ShuffleInst->getOperand(0))->getOpcode() !=
Instruction::InsertElement)
return false;
Ptr = cast<InsertElementInst>(ShuffleInst->getOperand(0))->getOperand(1);
SelectionDAG& DAG = SDB->DAG;
const TargetLowering &TLI = DAG.getTargetLoweringInfo();
// Check is the Ptr is inside current basic block
// If not, look for the shuffle instruction
if (SDB->findValue(Ptr))
Base = SDB->getValue(Ptr);
else if (SDB->findValue(ShuffleInst)) {
SDValue ShuffleNode = SDB->getValue(ShuffleInst);
Base = DAG.getNode(ISD::EXTRACT_VECTOR_ELT, SDLoc(ShuffleNode),
ShuffleNode.getValueType().getScalarType(), ShuffleNode,
DAG.getConstant(0, TLI.getVectorIdxTy()));
SDB->setValue(Ptr, Base);
}
else
return false;
Value *IndexVal = Gep->getOperand(1);
if (SDB->findValue(IndexVal)) {
Index = SDB->getValue(IndexVal);
if (SExtInst* Sext = dyn_cast<SExtInst>(IndexVal)) {
IndexVal = Sext->getOperand(0);
if (SDB->findValue(IndexVal))
Index = SDB->getValue(IndexVal);
}
return true;
}
return false;
}
void SelectionDAGBuilder::visitMaskedScatter(const CallInst &I) {
SDLoc sdl = getCurSDLoc();
// llvm.masked.scatter.*(Src0, Ptrs, alignemt, Mask)
Value *Ptr = I.getArgOperand(1);
SDValue Src0 = getValue(I.getArgOperand(0));
SDValue Mask = getValue(I.getArgOperand(3));
EVT VT = Src0.getValueType();
unsigned Alignment = (cast<ConstantInt>(I.getArgOperand(2)))->getZExtValue();
if (!Alignment)
Alignment = DAG.getEVTAlignment(VT);
const TargetLowering &TLI = DAG.getTargetLoweringInfo();
AAMDNodes AAInfo;
I.getAAMetadata(AAInfo);
SDValue Base;
SDValue Index;
Value *BasePtr = Ptr;
bool UniformBase = getUniformBase(BasePtr, Base, Index, this);
Value *MemOpBasePtr = UniformBase ? BasePtr : NULL;
MachineMemOperand *MMO = DAG.getMachineFunction().
getMachineMemOperand(MachinePointerInfo(MemOpBasePtr),
MachineMemOperand::MOStore, VT.getStoreSize(),
Alignment, AAInfo);
if (!UniformBase) {
Base = DAG.getTargetConstant(0, TLI.getPointerTy());
Index = getValue(Ptr);
}
SDValue Ops[] = { getRoot(), Src0, Mask, Base, Index };
SDValue Scatter = DAG.getMaskedScatter(DAG.getVTList(MVT::Other), VT, sdl, Ops, MMO);
DAG.setRoot(Scatter);
setValue(&I, Scatter);
}
void SelectionDAGBuilder::visitMaskedLoad(const CallInst &I) {
SDLoc sdl = getCurSDLoc();
@ -3067,6 +3159,60 @@ void SelectionDAGBuilder::visitMaskedLoad(const CallInst &I) {
setValue(&I, Load);
}
void SelectionDAGBuilder::visitMaskedGather(const CallInst &I) {
SDLoc sdl = getCurSDLoc();
// @llvm.masked.gather.*(Ptrs, alignment, Mask, Src0)
Value *Ptr = I.getArgOperand(0);
SDValue Src0 = getValue(I.getArgOperand(3));
SDValue Mask = getValue(I.getArgOperand(2));
const TargetLowering &TLI = DAG.getTargetLoweringInfo();
EVT VT = TLI.getValueType(I.getType());
unsigned Alignment = (cast<ConstantInt>(I.getArgOperand(1)))->getZExtValue();
if (!Alignment)
Alignment = DAG.getEVTAlignment(VT);
AAMDNodes AAInfo;
I.getAAMetadata(AAInfo);
const MDNode *Ranges = I.getMetadata(LLVMContext::MD_range);
SDValue Root = DAG.getRoot();
SDValue Base;
SDValue Index;
Value *BasePtr = Ptr;
bool UniformBase = getUniformBase(BasePtr, Base, Index, this);
bool ConstantMemory = false;
if (UniformBase && AA->pointsToConstantMemory(
AliasAnalysis::Location(BasePtr,
AA->getTypeStoreSize(I.getType()),
AAInfo))) {
// Do not serialize (non-volatile) loads of constant memory with anything.
Root = DAG.getEntryNode();
ConstantMemory = true;
}
MachineMemOperand *MMO =
DAG.getMachineFunction().
getMachineMemOperand(MachinePointerInfo(UniformBase ? BasePtr : NULL),
MachineMemOperand::MOLoad, VT.getStoreSize(),
Alignment, AAInfo, Ranges);
if (!UniformBase) {
Base = DAG.getTargetConstant(0, TLI.getPointerTy());
Index = getValue(Ptr);
}
SDValue Ops[] = { Root, Src0, Mask, Base, Index };
SDValue Gather = DAG.getMaskedGather(DAG.getVTList(VT, MVT::Other), VT, sdl,
Ops, MMO);
SDValue OutChain = Gather.getValue(1);
if (!ConstantMemory)
PendingLoads.push_back(OutChain);
setValue(&I, Gather);
}
void SelectionDAGBuilder::visitAtomicCmpXchg(const AtomicCmpXchgInst &I) {
SDLoc dl = getCurSDLoc();
AtomicOrdering SuccessOrder = I.getSuccessOrdering();
@ -4216,9 +4362,13 @@ SelectionDAGBuilder::visitIntrinsicCall(const CallInst &I, unsigned Intrinsic) {
return nullptr;
}
case Intrinsic::masked_gather:
visitMaskedGather(I);
case Intrinsic::masked_load:
visitMaskedLoad(I);
return nullptr;
case Intrinsic::masked_scatter:
visitMaskedScatter(I);
case Intrinsic::masked_store:
visitMaskedStore(I);
return nullptr;

View File

@ -667,6 +667,8 @@ public:
// generate the debug data structures now that we've seen its definition.
void resolveDanglingDebugInfo(const Value *V, SDValue Val);
SDValue getValue(const Value *V);
bool findValue(const Value *V) const;
SDValue getNonRegisterValue(const Value *V);
SDValue getValueImpl(const Value *V);
@ -814,6 +816,8 @@ private:
void visitStore(const StoreInst &I);
void visitMaskedLoad(const CallInst &I);
void visitMaskedStore(const CallInst &I);
void visitMaskedGather(const CallInst &I);
void visitMaskedScatter(const CallInst &I);
void visitAtomicCmpXchg(const AtomicCmpXchgInst &I);
void visitAtomicRMW(const AtomicRMWInst &I);
void visitFence(const FenceInst &I);