From 4105fd49d4643f06d746cfd3154b1c0bb8672483 Mon Sep 17 00:00:00 2001 From: Elena Demikhovsky Date: Wed, 25 Feb 2015 09:46:31 +0000 Subject: [PATCH] AVX-512: Gather and Scatter patterns Gather and scatter instructions additionally write to one of the source operands - mask register. In this case Gather has 2 destination values - the loaded value and the mask. Till now we did not support code gen pattern for gather - the instruction was generated from intrinsic only and machine node was hardcoded. When we introduce the masked_gather node, we need to select instruction automatically, in the standard way. I added a flag "hasTwoExplicitDefs" that allows to handle 2 destination operands. (Some code in the X86InstrFragmentsSIMD.td is commented out, just to split one big patch in many small patches) git-svn-id: https://llvm.org/svn/llvm-project/llvm/trunk@230471 91177308-0d34-0410-b5e6-96231b3b80d8 --- include/llvm/CodeGen/ISDOpcodes.h | 2 + include/llvm/Target/Target.td | 3 +- include/llvm/Target/TargetSelectionDAG.td | 12 +++ lib/Target/X86/X86InstrAVX512.td | 95 ++++++++++++----------- lib/Target/X86/X86InstrFragmentsSIMD.td | 54 +++++++++++++ lib/Target/X86/X86InstrInfo.td | 3 + utils/TableGen/CodeGenDAGPatterns.cpp | 20 +++-- utils/TableGen/CodeGenInstruction.cpp | 1 + utils/TableGen/CodeGenInstruction.h | 1 + 9 files changed, 141 insertions(+), 50 deletions(-) diff --git a/include/llvm/CodeGen/ISDOpcodes.h b/include/llvm/CodeGen/ISDOpcodes.h index 2d1c8cd6fdd..8a31f7ed871 100644 --- a/include/llvm/CodeGen/ISDOpcodes.h +++ b/include/llvm/CodeGen/ISDOpcodes.h @@ -689,6 +689,8 @@ namespace ISD { // Masked load and store MLOAD, MSTORE, + // Masked gather and scatter + MGATHER, MSCATTER, /// This corresponds to the llvm.lifetime.* intrinsics. The first operand /// is the chain and the second operand is the alloca pointer. diff --git a/include/llvm/Target/Target.td b/include/llvm/Target/Target.td index 6c970d0c19d..3e65a5d7956 100644 --- a/include/llvm/Target/Target.td +++ b/include/llvm/Target/Target.td @@ -397,7 +397,8 @@ class Instruction { // captured by any operands of the instruction or other flags. // bit hasSideEffects = ?; - + bit hasTwoExplicitDefs = 0; // Does this instruction have 2 explicit + // destinations? // Is this instruction a "real" instruction (with a distinct machine // encoding), or is it a pseudo instruction used for codegen modeling // purposes. diff --git a/include/llvm/Target/TargetSelectionDAG.td b/include/llvm/Target/TargetSelectionDAG.td index 2ecd900d34b..d297162a7f7 100644 --- a/include/llvm/Target/TargetSelectionDAG.td +++ b/include/llvm/Target/TargetSelectionDAG.td @@ -196,6 +196,14 @@ def SDTMaskedLoad: SDTypeProfile<1, 3, [ // masked load SDTCisVec<0>, SDTCisPtrTy<1>, SDTCisVec<2>, SDTCisSameAs<0, 3> ]>; +def SDTMaskedGather: SDTypeProfile<1, 3, [ // masked gather + SDTCisVec<0>, SDTCisSameAs<0, 1>, SDTCisVec<2> +]>; + +def SDTMaskedScatter: SDTypeProfile<1, 3, [ // masked scatter + SDTCisVec<0>, SDTCisVec<1>, SDTCisSameAs<0, 2> +]>; + def SDTVecShuffle : SDTypeProfile<1, 2, [ SDTCisSameAs<0, 1>, SDTCisSameAs<1, 2> ]>; @@ -468,6 +476,10 @@ def masked_store : SDNode<"ISD::MSTORE", SDTMaskedStore, [SDNPHasChain, SDNPMayStore, SDNPMemOperand]>; def masked_load : SDNode<"ISD::MLOAD", SDTMaskedLoad, [SDNPHasChain, SDNPMayLoad, SDNPMemOperand]>; +def masked_scatter : SDNode<"ISD::MSCATTER", SDTMaskedScatter, + [SDNPHasChain, SDNPMayStore, SDNPMemOperand]>; +def masked_gather : SDNode<"ISD::MGATHER", SDTMaskedGather, + [SDNPHasChain, SDNPMayLoad, SDNPMemOperand]>; // Do not use ld, st directly. Use load, extload, sextload, zextload, store, // and truncst (see below). diff --git a/lib/Target/X86/X86InstrAVX512.td b/lib/Target/X86/X86InstrAVX512.td index 0b2392b9bce..4923bc5f1dd 100644 --- a/lib/Target/X86/X86InstrAVX512.td +++ b/lib/Target/X86/X86InstrAVX512.td @@ -4919,74 +4919,81 @@ defm VPMOVSXDQZ: avx512_extend<0x25, "vpmovsxdq", VK8WM, VR512, VR256X, X86vsext //===----------------------------------------------------------------------===// // GATHER - SCATTER Operations -multiclass avx512_gather opc, string OpcodeStr, RegisterClass KRC, - RegisterClass RC, X86MemOperand memop> { -let mayLoad = 1, +multiclass avx512_gather opc, string OpcodeStr, X86VectorVTInfo _, + X86MemOperand memop, PatFrag GatherNode> { +let mayLoad = 1, hasTwoExplicitDefs = 1, Constraints = "@earlyclobber $dst, $src1 = $dst, $mask = $mask_wb" in - def rm : AVX5128I, EVEX, EVEX_K; + [(set _.RC:$dst, _.KRCWM:$mask_wb, + (_.VT (GatherNode (_.VT _.RC:$src1), _.KRCWM:$mask, + vectoraddr:$src2)))]>, EVEX, EVEX_K, + EVEX_CD8<_.EltSize, CD8VT1>; } let ExeDomain = SSEPackedDouble in { -defm VGATHERDPDZ : avx512_gather<0x92, "vgatherdpd", VK8WM, VR512, vy64xmem>, - EVEX_V512, VEX_W, EVEX_CD8<64, CD8VT1>; -defm VGATHERQPDZ : avx512_gather<0x93, "vgatherqpd", VK8WM, VR512, vz64mem>, - EVEX_V512, VEX_W, EVEX_CD8<64, CD8VT1>; +defm VGATHERDPDZ : avx512_gather<0x92, "vgatherdpd", v8f64_info, vy64xmem, + mgatherv8i32>, EVEX_V512, VEX_W; +defm VGATHERQPDZ : avx512_gather<0x93, "vgatherqpd", v8f64_info, vz64mem, + mgatherv8i64>, EVEX_V512, VEX_W; } let ExeDomain = SSEPackedSingle in { -defm VGATHERDPSZ : avx512_gather<0x92, "vgatherdps", VK16WM, VR512, vz32mem>, - EVEX_V512, EVEX_CD8<32, CD8VT1>; -defm VGATHERQPSZ : avx512_gather<0x93, "vgatherqps", VK8WM, VR256X, vz64mem>, - EVEX_V512, EVEX_CD8<32, CD8VT1>; +defm VGATHERDPSZ : avx512_gather<0x92, "vgatherdps", v16f32_info, vz32mem, + mgatherv16i32>, EVEX_V512; +defm VGATHERQPSZ : avx512_gather<0x93, "vgatherqps", v8f32x_info, vz64mem, + mgatherv8i64>, EVEX_V512; } -defm VPGATHERDQZ : avx512_gather<0x90, "vpgatherdq", VK8WM, VR512, vy64xmem>, - EVEX_V512, VEX_W, EVEX_CD8<64, CD8VT1>; -defm VPGATHERDDZ : avx512_gather<0x90, "vpgatherdd", VK16WM, VR512, vz32mem>, - EVEX_V512, EVEX_CD8<32, CD8VT1>; +defm VPGATHERDQZ : avx512_gather<0x90, "vpgatherdq", v8i64_info, vy64xmem, + mgatherv8i32>, EVEX_V512, VEX_W; +defm VPGATHERDDZ : avx512_gather<0x90, "vpgatherdd", v16i32_info, vz32mem, + mgatherv16i32>, EVEX_V512; -defm VPGATHERQQZ : avx512_gather<0x91, "vpgatherqq", VK8WM, VR512, vz64mem>, - EVEX_V512, VEX_W, EVEX_CD8<64, CD8VT1>; -defm VPGATHERQDZ : avx512_gather<0x91, "vpgatherqd", VK8WM, VR256X, vz64mem>, - EVEX_V512, EVEX_CD8<32, CD8VT1>; +defm VPGATHERQQZ : avx512_gather<0x91, "vpgatherqq", v8i64_info, vz64mem, + mgatherv8i64>, EVEX_V512, VEX_W; +defm VPGATHERQDZ : avx512_gather<0x91, "vpgatherqd", v8i32x_info, vz64mem, + mgatherv8i64>, EVEX_V512; + +multiclass avx512_scatter opc, string OpcodeStr, X86VectorVTInfo _, + X86MemOperand memop, PatFrag ScatterNode> { -multiclass avx512_scatter opc, string OpcodeStr, RegisterClass KRC, - RegisterClass RC, X86MemOperand memop> { let mayStore = 1, Constraints = "$mask = $mask_wb" in - def mr : AVX5128I, EVEX, EVEX_K; + "\t{$src, ${dst} {${mask}}|${dst} {${mask}}, $src}"), + [(set _.KRCWM:$mask_wb, (ScatterNode (_.VT _.RC:$src), + _.KRCWM:$mask, vectoraddr:$dst))]>, + EVEX, EVEX_K, EVEX_CD8<_.EltSize, CD8VT1>; } let ExeDomain = SSEPackedDouble in { -defm VSCATTERDPDZ : avx512_scatter<0xA2, "vscatterdpd", VK8WM, VR512, vy64xmem>, - EVEX_V512, VEX_W, EVEX_CD8<64, CD8VT1>; -defm VSCATTERQPDZ : avx512_scatter<0xA3, "vscatterqpd", VK8WM, VR512, vz64mem>, - EVEX_V512, VEX_W, EVEX_CD8<64, CD8VT1>; +defm VSCATTERDPDZ : avx512_scatter<0xA2, "vscatterdpd", v8f64_info, vy64xmem, + mscatterv8i32>, EVEX_V512, VEX_W; +defm VSCATTERQPDZ : avx512_scatter<0xA3, "vscatterqpd", v8f64_info, vz64mem, + mscatterv8i64>, EVEX_V512, VEX_W; } let ExeDomain = SSEPackedSingle in { -defm VSCATTERDPSZ : avx512_scatter<0xA2, "vscatterdps", VK16WM, VR512, vz32mem>, - EVEX_V512, EVEX_CD8<32, CD8VT1>; -defm VSCATTERQPSZ : avx512_scatter<0xA3, "vscatterqps", VK8WM, VR256X, vz64mem>, - EVEX_V512, EVEX_CD8<32, CD8VT1>; +defm VSCATTERDPSZ : avx512_scatter<0xA2, "vscatterdps", v16f32_info, vz32mem, + mscatterv16i32>, EVEX_V512; +defm VSCATTERQPSZ : avx512_scatter<0xA3, "vscatterqps", v8f32x_info, vz64mem, + mscatterv8i64>, EVEX_V512; } -defm VPSCATTERDQZ : avx512_scatter<0xA0, "vpscatterdq", VK8WM, VR512, vy64xmem>, - EVEX_V512, VEX_W, EVEX_CD8<64, CD8VT1>; -defm VPSCATTERDDZ : avx512_scatter<0xA0, "vpscatterdd", VK16WM, VR512, vz32mem>, - EVEX_V512, EVEX_CD8<32, CD8VT1>; +defm VPSCATTERDQZ : avx512_scatter<0xA0, "vpscatterdq", v8i64_info, vy64xmem, + mscatterv8i32>, EVEX_V512, VEX_W; +defm VPSCATTERDDZ : avx512_scatter<0xA0, "vpscatterdd", v16i32_info, vz32mem, + mscatterv16i32>, EVEX_V512; -defm VPSCATTERQQZ : avx512_scatter<0xA1, "vpscatterqq", VK8WM, VR512, vz64mem>, - EVEX_V512, VEX_W, EVEX_CD8<64, CD8VT1>; -defm VPSCATTERQDZ : avx512_scatter<0xA1, "vpscatterqd", VK8WM, VR256X, vz64mem>, - EVEX_V512, EVEX_CD8<32, CD8VT1>; +defm VPSCATTERQQZ : avx512_scatter<0xA1, "vpscatterqq", v8i64_info, vz64mem, + mscatterv8i64>, EVEX_V512, VEX_W; +defm VPSCATTERQDZ : avx512_scatter<0xA1, "vpscatterqd", v8i32x_info, vz64mem, + mscatterv8i64>, EVEX_V512; // prefetch multiclass avx512_gather_scatter_prefetch opc, Format F, string OpcodeStr, diff --git a/lib/Target/X86/X86InstrFragmentsSIMD.td b/lib/Target/X86/X86InstrFragmentsSIMD.td index a80843fc48f..bf515a847db 100644 --- a/lib/Target/X86/X86InstrFragmentsSIMD.td +++ b/lib/Target/X86/X86InstrFragmentsSIMD.td @@ -304,6 +304,8 @@ def X86exp2 : SDNode<"X86ISD::EXP2", STDFp1SrcRm>; def X86rsqrt28s : SDNode<"X86ISD::RSQRT28", STDFp2SrcRm>; def X86rcp28s : SDNode<"X86ISD::RCP28", STDFp2SrcRm>; def X86RndScale : SDNode<"X86ISD::RNDSCALE", STDFp3SrcRm>; +def X86mgather : SDNode<"X86ISD::GATHER", SDTypeProfile<1, 3, + [SDTCisVec<0>, SDTCisVec<1>, SDTCisVec<2>]>>; def SDT_PCMPISTRI : SDTypeProfile<2, 3, [SDTCisVT<0, i32>, SDTCisVT<1, i32>, SDTCisVT<2, v16i8>, SDTCisVT<3, v16i8>, @@ -524,6 +526,58 @@ def unalignednontemporalstore : PatFrag<(ops node:$val, node:$ptr), return false; }]>; +def mgatherv8i32 : PatFrag<(ops node:$src1, node:$src2, node:$src3), + (masked_gather node:$src1, node:$src2, node:$src3) , [{ + //if (MaskedGatherSDNode *Mgt = dyn_cast(N)) + // return (Mgt->getIndex().getValueType() == MVT::v8i32 || + // Mgt->getBasePtr().getValueType() == MVT::v8i32); + //return false; + return N != 0; +}]>; + +def mgatherv8i64 : PatFrag<(ops node:$src1, node:$src2, node:$src3), + (masked_gather node:$src1, node:$src2, node:$src3) , [{ + //if (MaskedGatherSDNode *Mgt = dyn_cast(N)) + // return (Mgt->getIndex().getValueType() == MVT::v8i64 || + // Mgt->getBasePtr().getValueType() == MVT::v8i64); + //return false; + return N != 0; +}]>; +def mgatherv16i32 : PatFrag<(ops node:$src1, node:$src2, node:$src3), + (masked_gather node:$src1, node:$src2, node:$src3) , [{ + //if (MaskedGatherSDNode *Mgt = dyn_cast(N)) + // return (Mgt->getIndex().getValueType() == MVT::v16i32 || + // Mgt->getBasePtr().getValueType() == MVT::v16i32); + //return false; + return N != 0; +}]>; + +def mscatterv8i32 : PatFrag<(ops node:$src1, node:$src2, node:$src3), + (masked_scatter node:$src1, node:$src2, node:$src3) , [{ + //if (MaskedScatterSDNode *Sc = dyn_cast(N)) + // return (Sc->getIndex().getValueType() == MVT::v8i32 || + // Sc->getBasePtr().getValueType() == MVT::v8i32); + //return false; + return N != 0; +}]>; + +def mscatterv8i64 : PatFrag<(ops node:$src1, node:$src2, node:$src3), + (masked_scatter node:$src1, node:$src2, node:$src3) , [{ + //if (MaskedScatterSDNode *Sc = dyn_cast(N)) + // return (Sc->getIndex().getValueType() == MVT::v8i64 || + // Sc->getBasePtr().getValueType() == MVT::v8i64); + //return false; + return N != 0; +}]>; +def mscatterv16i32 : PatFrag<(ops node:$src1, node:$src2, node:$src3), + (masked_scatter node:$src1, node:$src2, node:$src3) , [{ + //if (MaskedScatterSDNode *Sc = dyn_cast(N)) + // return (Sc->getIndex().getValueType() == MVT::v16i32 || + // Sc->getBasePtr().getValueType() == MVT::v16i32); + //return false; + return N != 0; +}]>; + // 128-bit bitconvert pattern fragments def bc_v4f32 : PatFrag<(ops node:$in), (v4f32 (bitconvert node:$in))>; def bc_v2f64 : PatFrag<(ops node:$in), (v2f64 (bitconvert node:$in))>; diff --git a/lib/Target/X86/X86InstrInfo.td b/lib/Target/X86/X86InstrInfo.td index 7ab8822ff31..9881cafa84d 100644 --- a/lib/Target/X86/X86InstrInfo.td +++ b/lib/Target/X86/X86InstrInfo.td @@ -713,6 +713,9 @@ def tls64addr : ComplexPattern; +def vectoraddr : ComplexPattern; +//def vectoraddr : ComplexPattern; + //===----------------------------------------------------------------------===// // X86 Instruction Predicate Definitions. def HasCMov : Predicate<"Subtarget->hasCMov()">; diff --git a/utils/TableGen/CodeGenDAGPatterns.cpp b/utils/TableGen/CodeGenDAGPatterns.cpp index 20bf29eb859..4e3e588fbad 100644 --- a/utils/TableGen/CodeGenDAGPatterns.cpp +++ b/utils/TableGen/CodeGenDAGPatterns.cpp @@ -1113,6 +1113,8 @@ static unsigned GetNumNodeResults(Record *Operator, CodeGenDAGPatterns &CDP) { // FIXME: Should allow access to all the results here. unsigned NumDefsToAdd = InstInfo.Operands.NumDefs ? 1 : 0; + if (InstInfo.hasTwoExplicitDefs) + ++NumDefsToAdd; // Add on one implicit def if it has a resolvable type. if (InstInfo.HasOneImplicitDefWithKnownVT(CDP.getTargetInfo()) !=MVT::Other) @@ -1609,11 +1611,20 @@ bool TreePatternNode::ApplyTypeConstraints(TreePattern &TP, bool NotRegisters) { assert(getNumTypes() == 0 && "Set doesn't produce a value"); assert(getNumChildren() >= 2 && "Missing RHS of a set?"); unsigned NC = getNumChildren(); + unsigned NumOfSrcs = NC-1; + // destination TreePatternNode *SetVal = getChild(NC-1); bool MadeChange = SetVal->ApplyTypeConstraints(TP, NotRegisters); - for (unsigned i = 0; i < NC-1; ++i) { + // second explicit destination + if (TP.getRecord()->getValueAsBit("hasTwoExplicitDefs")) { + TreePatternNode *Set2Val = getChild(NC-2); + MadeChange = Set2Val->ApplyTypeConstraints(TP, NotRegisters); + NumOfSrcs --; + } + + for (unsigned i = 0; i < NumOfSrcs; ++i) { TreePatternNode *Child = getChild(i); MadeChange |= Child->ApplyTypeConstraints(TP, NotRegisters); @@ -2856,7 +2867,7 @@ const DAGInstruction &CodeGenDAGPatterns::parseInstructionPattern( // Check that all of the results occur first in the list. std::vector Results; - TreePatternNode *Res0Node = nullptr; + SmallVector ResNode; for (unsigned i = 0; i != NumResults; ++i) { if (i == CGI.Operands.size()) I->error("'" + InstResults.begin()->first + @@ -2868,8 +2879,7 @@ const DAGInstruction &CodeGenDAGPatterns::parseInstructionPattern( if (!RNode) I->error("Operand $" + OpName + " does not exist in operand list!"); - if (i == 0) - Res0Node = RNode; + ResNode.push_back(RNode); Record *R = cast(RNode->getLeafValue())->getDef(); if (!R) I->error("Operand $" + OpName + " should be a set destination: all " @@ -2946,7 +2956,7 @@ const DAGInstruction &CodeGenDAGPatterns::parseInstructionPattern( GetNumNodeResults(I->getRecord(), *this)); // Copy fully inferred output node type to instruction result pattern. for (unsigned i = 0; i != NumResults; ++i) - ResultPattern->setType(i, Res0Node->getExtType(i)); + ResultPattern->setType(i, ResNode[i]->getExtType(0)); // Create and insert the instruction. // FIXME: InstImpResults should not be part of DAGInstruction. diff --git a/utils/TableGen/CodeGenInstruction.cpp b/utils/TableGen/CodeGenInstruction.cpp index 10602964e48..b1e43183634 100644 --- a/utils/TableGen/CodeGenInstruction.cpp +++ b/utils/TableGen/CodeGenInstruction.cpp @@ -320,6 +320,7 @@ CodeGenInstruction::CodeGenInstruction(Record *R) isRegSequence = R->getValueAsBit("isRegSequence"); isExtractSubreg = R->getValueAsBit("isExtractSubreg"); isInsertSubreg = R->getValueAsBit("isInsertSubreg"); + hasTwoExplicitDefs = R->getValueAsBit("hasTwoExplicitDefs"); bool Unset; mayLoad = R->getValueAsBitOrUnset("mayLoad", Unset); diff --git a/utils/TableGen/CodeGenInstruction.h b/utils/TableGen/CodeGenInstruction.h index bdbe546ec97..82b23f465f3 100644 --- a/utils/TableGen/CodeGenInstruction.h +++ b/utils/TableGen/CodeGenInstruction.h @@ -255,6 +255,7 @@ namespace llvm { bool isRegSequence : 1; bool isExtractSubreg : 1; bool isInsertSubreg : 1; + bool hasTwoExplicitDefs : 1; std::string DeprecatedReason; bool HasComplexDeprecationPredicate;