From 6e8b53da17ed93c3de0922e94add5f11f068e184 Mon Sep 17 00:00:00 2001 From: Elena Demikhovsky Date: Thu, 8 Jan 2015 12:29:19 +0000 Subject: [PATCH] Masked Load/Store - fixed a bug in type legalization. git-svn-id: https://llvm.org/svn/llvm-project/llvm/trunk@225441 91177308-0d34-0410-b5e6-96231b3b80d8 --- .../SelectionDAG/LegalizeIntegerTypes.cpp | 63 ++++++++++++++++++- lib/CodeGen/SelectionDAG/LegalizeTypes.h | 2 + .../SelectionDAG/LegalizeVectorTypes.cpp | 45 +++++++++++++ test/CodeGen/X86/masked_memop.ll | 49 +++++++++++++++ 4 files changed, 156 insertions(+), 3 deletions(-) diff --git a/lib/CodeGen/SelectionDAG/LegalizeIntegerTypes.cpp b/lib/CodeGen/SelectionDAG/LegalizeIntegerTypes.cpp index 52c2d1be430..82b114b80aa 100644 --- a/lib/CodeGen/SelectionDAG/LegalizeIntegerTypes.cpp +++ b/lib/CodeGen/SelectionDAG/LegalizeIntegerTypes.cpp @@ -66,6 +66,7 @@ void DAGTypeLegalizer::PromoteIntegerResult(SDNode *N, unsigned ResNo) { case ISD::EXTRACT_VECTOR_ELT: Res = PromoteIntRes_EXTRACT_VECTOR_ELT(N); break; case ISD::LOAD: Res = PromoteIntRes_LOAD(cast(N));break; + case ISD::MLOAD: Res = PromoteIntRes_MLOAD(cast(N));break; case ISD::SELECT: Res = PromoteIntRes_SELECT(N); break; case ISD::VSELECT: Res = PromoteIntRes_VSELECT(N); break; case ISD::SELECT_CC: Res = PromoteIntRes_SELECT_CC(N); break; @@ -454,6 +455,24 @@ SDValue DAGTypeLegalizer::PromoteIntRes_LOAD(LoadSDNode *N) { return Res; } +SDValue DAGTypeLegalizer::PromoteIntRes_MLOAD(MaskedLoadSDNode *N) { + EVT NVT = TLI.getTypeToTransformTo(*DAG.getContext(), N->getValueType(0)); + SDValue ExtSrc0 = GetPromotedInteger(N->getSrc0()); + SDValue ExtMask = PromoteTargetBoolean(N->getMask(), NVT); + SDLoc dl(N); + + MachineMemOperand *MMO = DAG.getMachineFunction(). + getMachineMemOperand(N->getPointerInfo(), + MachineMemOperand::MOLoad, NVT.getStoreSize(), + N->getAlignment(), N->getAAInfo(), N->getRanges()); + + SDValue Res = DAG.getMaskedLoad(NVT, dl, N->getChain(), N->getBasePtr(), + ExtMask, ExtSrc0, MMO); + // Legalized the chain result - switch anything that used the old chain to + // use the new one. + ReplaceValueWith(SDValue(N, 1), Res.getValue(1)); + return Res; +} /// Promote the overflow flag of an overflowing arithmetic node. SDValue DAGTypeLegalizer::PromoteIntRes_Overflow(SDNode *N) { // Simply change the return type of the boolean result. @@ -1098,10 +1117,48 @@ SDValue DAGTypeLegalizer::PromoteIntOp_STORE(StoreSDNode *N, unsigned OpNo){ SDValue DAGTypeLegalizer::PromoteIntOp_MSTORE(MaskedStoreSDNode *N, unsigned OpNo){ assert(OpNo == 2 && "Only know how to promote the mask!"); - EVT DataVT = N->getOperand(3).getValueType(); - SDValue Mask = PromoteTargetBoolean(N->getOperand(OpNo), DataVT); + SDValue DataOp = N->getData(); + EVT DataVT = DataOp.getValueType(); + SDValue Mask = N->getMask(); + EVT MaskVT = Mask.getValueType(); + SDLoc dl(N); + + if (!TLI.isTypeLegal(DataVT)) { + if (getTypeAction(DataVT) == TargetLowering::TypePromoteInteger) { + DataOp = GetPromotedInteger(DataOp); + Mask = PromoteTargetBoolean(Mask, DataOp.getValueType()); + } + else { + assert(getTypeAction(DataVT) == TargetLowering::TypeWidenVector && + "Unexpected data legalization in MSTORE"); + DataOp = GetWidenedVector(DataOp); + + if (getTypeAction(MaskVT) == TargetLowering::TypeWidenVector) + Mask = GetWidenedVector(Mask); + else { + EVT BoolVT = getSetCCResultType(DataOp.getValueType()); + + // We can't use ModifyToType() because we should fill the mask with + // zeroes + unsigned WidenNumElts = BoolVT.getVectorNumElements(); + unsigned MaskNumElts = MaskVT.getVectorNumElements(); + + unsigned NumConcat = WidenNumElts / MaskNumElts; + SmallVector Ops(NumConcat); + SDValue ZeroVal = DAG.getConstant(0, MaskVT); + Ops[0] = Mask; + for (unsigned i = 1; i != NumConcat; ++i) + Ops[i] = ZeroVal; + + Mask = DAG.getNode(ISD::CONCAT_VECTORS, dl, BoolVT, Ops); + } + } + } + else + Mask = PromoteTargetBoolean(N->getMask(), DataOp.getValueType()); SmallVector NewOps(N->op_begin(), N->op_end()); - NewOps[OpNo] = Mask; + NewOps[2] = Mask; + NewOps[3] = DataOp; return SDValue(DAG.UpdateNodeOperands(N, NewOps), 0); } diff --git a/lib/CodeGen/SelectionDAG/LegalizeTypes.h b/lib/CodeGen/SelectionDAG/LegalizeTypes.h index 805b0fc0463..1cd9f407bca 100644 --- a/lib/CodeGen/SelectionDAG/LegalizeTypes.h +++ b/lib/CodeGen/SelectionDAG/LegalizeTypes.h @@ -240,6 +240,7 @@ private: SDValue PromoteIntRes_FP_TO_FP16(SDNode *N); SDValue PromoteIntRes_INT_EXTEND(SDNode *N); SDValue PromoteIntRes_LOAD(LoadSDNode *N); + SDValue PromoteIntRes_MLOAD(MaskedLoadSDNode *N); SDValue PromoteIntRes_Overflow(SDNode *N); SDValue PromoteIntRes_SADDSUBO(SDNode *N, unsigned ResNo); SDValue PromoteIntRes_SDIV(SDNode *N); @@ -631,6 +632,7 @@ private: SDValue WidenVecRes_EXTRACT_SUBVECTOR(SDNode* N); SDValue WidenVecRes_INSERT_VECTOR_ELT(SDNode* N); SDValue WidenVecRes_LOAD(SDNode* N); + SDValue WidenVecRes_MLOAD(MaskedLoadSDNode* N); SDValue WidenVecRes_SCALAR_TO_VECTOR(SDNode* N); SDValue WidenVecRes_SIGN_EXTEND_INREG(SDNode* N); SDValue WidenVecRes_SELECT(SDNode* N); diff --git a/lib/CodeGen/SelectionDAG/LegalizeVectorTypes.cpp b/lib/CodeGen/SelectionDAG/LegalizeVectorTypes.cpp index 88f67370228..96b69eec335 100644 --- a/lib/CodeGen/SelectionDAG/LegalizeVectorTypes.cpp +++ b/lib/CodeGen/SelectionDAG/LegalizeVectorTypes.cpp @@ -1713,6 +1713,9 @@ void DAGTypeLegalizer::WidenVectorResult(SDNode *N, unsigned ResNo) { case ISD::VECTOR_SHUFFLE: Res = WidenVecRes_VECTOR_SHUFFLE(cast(N)); break; + case ISD::MLOAD: + Res = WidenVecRes_MLOAD(cast(N)); + break; case ISD::ADD: case ISD::AND: @@ -2403,6 +2406,48 @@ SDValue DAGTypeLegalizer::WidenVecRes_LOAD(SDNode *N) { return Result; } +SDValue DAGTypeLegalizer::WidenVecRes_MLOAD(MaskedLoadSDNode *N) { + + EVT WidenVT = TLI.getTypeToTransformTo(*DAG.getContext(),N->getValueType(0)); + SDValue Mask = N->getMask(); + EVT MaskVT = Mask.getValueType(); + SDValue Src0 = GetWidenedVector(N->getSrc0()); + SDLoc dl(N); + + if (getTypeAction(MaskVT) == TargetLowering::TypeWidenVector) + Mask = GetWidenedVector(Mask); + else { + EVT BoolVT = getSetCCResultType(WidenVT); + + // We can't use ModifyToType() because we should fill the mask with + // zeroes + unsigned WidenNumElts = BoolVT.getVectorNumElements(); + unsigned MaskNumElts = MaskVT.getVectorNumElements(); + + unsigned NumConcat = WidenNumElts / MaskNumElts; + SmallVector Ops(NumConcat); + SDValue ZeroVal = DAG.getConstant(0, MaskVT); + Ops[0] = Mask; + for (unsigned i = 1; i != NumConcat; ++i) + Ops[i] = ZeroVal; + + Mask = DAG.getNode(ISD::CONCAT_VECTORS, dl, BoolVT, Ops); + } + + // Rebuild memory operand because MemoryVT was changed + MachineMemOperand *MMO = DAG.getMachineFunction(). + getMachineMemOperand(N->getPointerInfo(), + MachineMemOperand::MOLoad, WidenVT.getStoreSize(), + N->getAlignment(), N->getAAInfo(), N->getRanges()); + + SDValue Res = DAG.getMaskedLoad(WidenVT, dl, N->getChain(), N->getBasePtr(), + Mask, Src0, MMO); + // Legalized the chain result - switch anything that used the old chain to + // use the new one. + ReplaceValueWith(SDValue(N, 1), Res.getValue(1)); + return Res; +} + SDValue DAGTypeLegalizer::WidenVecRes_SCALAR_TO_VECTOR(SDNode *N) { EVT WidenVT = TLI.getTypeToTransformTo(*DAG.getContext(), N->getValueType(0)); return DAG.getNode(ISD::SCALAR_TO_VECTOR, SDLoc(N), diff --git a/test/CodeGen/X86/masked_memop.ll b/test/CodeGen/X86/masked_memop.ll index 1a28098bd1c..cce2d909120 100644 --- a/test/CodeGen/X86/masked_memop.ll +++ b/test/CodeGen/X86/masked_memop.ll @@ -149,16 +149,65 @@ define void @test13(<16 x i32> %trigger, <16 x float>* %addr, <16 x float> %val) ret void } +; AVX2-LABEL: test14 +; AVX2: vshufps $-24 +; AVX2: vmaskmovps +define void @test14(<2 x i32> %trigger, <2 x float>* %addr, <2 x float> %val) { + %mask = icmp eq <2 x i32> %trigger, zeroinitializer + call void @llvm.masked.store.v2f32(<2 x float>%val, <2 x float>* %addr, i32 4, <2 x i1>%mask) + ret void +} + +; AVX2-LABEL: test15 +; AVX2: vpmaskmovq +define void @test15(<2 x i32> %trigger, <2 x i32>* %addr, <2 x i32> %val) { + %mask = icmp eq <2 x i32> %trigger, zeroinitializer + call void @llvm.masked.store.v2i32(<2 x i32>%val, <2 x i32>* %addr, i32 4, <2 x i1>%mask) + ret void +} + +; AVX2-LABEL: test16 +; AVX2: vmaskmovps +; AVX2: vblendvps +define <2 x float> @test16(<2 x i32> %trigger, <2 x float>* %addr, <2 x float> %dst) { + %mask = icmp eq <2 x i32> %trigger, zeroinitializer + %res = call <2 x float> @llvm.masked.load.v2f32(<2 x float>* %addr, i32 4, <2 x i1>%mask, <2 x float>%dst) + ret <2 x float> %res +} + +; AVX2-LABEL: test17 +; AVX2: vpmaskmovq +; AVX2: vblendvpd +define <2 x i32> @test17(<2 x i32> %trigger, <2 x i32>* %addr, <2 x i32> %dst) { + %mask = icmp eq <2 x i32> %trigger, zeroinitializer + %res = call <2 x i32> @llvm.masked.load.v2i32(<2 x i32>* %addr, i32 4, <2 x i1>%mask, <2 x i32>%dst) + ret <2 x i32> %res +} + +; AVX2-LABEL: test18 +; AVX2: vmaskmovps +; AVX2-NOT: blend +define <2 x float> @test18(<2 x i32> %trigger, <2 x float>* %addr) { + %mask = icmp eq <2 x i32> %trigger, zeroinitializer + %res = call <2 x float> @llvm.masked.load.v2f32(<2 x float>* %addr, i32 4, <2 x i1>%mask, <2 x float>undef) + ret <2 x float> %res +} + + declare <16 x i32> @llvm.masked.load.v16i32(<16 x i32>*, i32, <16 x i1>, <16 x i32>) declare <4 x i32> @llvm.masked.load.v4i32(<4 x i32>*, i32, <4 x i1>, <4 x i32>) +declare <2 x i32> @llvm.masked.load.v2i32(<2 x i32>*, i32, <2 x i1>, <2 x i32>) declare void @llvm.masked.store.v16i32(<16 x i32>, <16 x i32>*, i32, <16 x i1>) declare void @llvm.masked.store.v8i32(<8 x i32>, <8 x i32>*, i32, <8 x i1>) declare void @llvm.masked.store.v4i32(<4 x i32>, <4 x i32>*, i32, <4 x i1>) +declare void @llvm.masked.store.v2f32(<2 x float>, <2 x float>*, i32, <2 x i1>) +declare void @llvm.masked.store.v2i32(<2 x i32>, <2 x i32>*, i32, <2 x i1>) declare void @llvm.masked.store.v16f32(<16 x float>, <16 x float>*, i32, <16 x i1>) declare void @llvm.masked.store.v16f32p(<16 x float>*, <16 x float>**, i32, <16 x i1>) declare <16 x float> @llvm.masked.load.v16f32(<16 x float>*, i32, <16 x i1>, <16 x float>) declare <8 x float> @llvm.masked.load.v8f32(<8 x float>*, i32, <8 x i1>, <8 x float>) declare <4 x float> @llvm.masked.load.v4f32(<4 x float>*, i32, <4 x i1>, <4 x float>) +declare <2 x float> @llvm.masked.load.v2f32(<2 x float>*, i32, <2 x i1>, <2 x float>) declare <8 x double> @llvm.masked.load.v8f64(<8 x double>*, i32, <8 x i1>, <8 x double>) declare <4 x double> @llvm.masked.load.v4f64(<4 x double>*, i32, <4 x i1>, <4 x double>) declare <2 x double> @llvm.masked.load.v2f64(<2 x double>*, i32, <2 x i1>, <2 x double>)