diff --git a/lib/Target/X86/X86ISelLowering.cpp b/lib/Target/X86/X86ISelLowering.cpp index 6e6285e44fc..410d8b08a8d 100644 --- a/lib/Target/X86/X86ISelLowering.cpp +++ b/lib/Target/X86/X86ISelLowering.cpp @@ -1365,6 +1365,8 @@ X86TargetLowering::X86TargetLowering(const X86TargetMachine &TM, setOperationAction(ISD::SELECT, MVT::v8f64, Custom); setOperationAction(ISD::SELECT, MVT::v8i64, Custom); setOperationAction(ISD::SELECT, MVT::v16f32, Custom); + setOperationAction(ISD::SELECT, MVT::v16i1, Custom); + setOperationAction(ISD::SELECT, MVT::v8i1, Custom); setOperationAction(ISD::ADD, MVT::v8i64, Legal); setOperationAction(ISD::ADD, MVT::v16i32, Legal); @@ -1467,6 +1469,8 @@ X86TargetLowering::X86TargetLowering(const X86TargetMachine &TM, setOperationAction(ISD::CONCAT_VECTORS, MVT::v64i1, Custom); setOperationAction(ISD::INSERT_SUBVECTOR, MVT::v32i1, Custom); setOperationAction(ISD::INSERT_SUBVECTOR, MVT::v64i1, Custom); + setOperationAction(ISD::SELECT, MVT::v32i1, Custom); + setOperationAction(ISD::SELECT, MVT::v64i1, Custom); for (int i = MVT::v32i8; i != MVT::v8i64; ++i) { const MVT VT = (MVT::SimpleValueType)i; @@ -1494,6 +1498,8 @@ X86TargetLowering::X86TargetLowering(const X86TargetMachine &TM, setOperationAction(ISD::CONCAT_VECTORS, MVT::v8i1, Custom); setOperationAction(ISD::INSERT_SUBVECTOR, MVT::v8i1, Custom); setOperationAction(ISD::INSERT_SUBVECTOR, MVT::v4i1, Custom); + setOperationAction(ISD::SELECT, MVT::v4i1, Custom); + setOperationAction(ISD::SELECT, MVT::v2i1, Custom); setOperationAction(ISD::AND, MVT::v8i32, Legal); setOperationAction(ISD::OR, MVT::v8i32, Legal); @@ -13609,6 +13615,17 @@ SDValue X86TargetLowering::LowerSELECT(SDValue Op, SelectionDAG &DAG) const { } } + if (VT == MVT::v4i1 || VT == MVT::v2i1) { + SDValue zeroConst = DAG.getIntPtrConstant(0, DL); + Op1 = DAG.getNode(ISD::INSERT_SUBVECTOR, DL, MVT::v8i1, + DAG.getUNDEF(MVT::v8i1), Op1, zeroConst); + Op2 = DAG.getNode(ISD::INSERT_SUBVECTOR, DL, MVT::v8i1, + DAG.getUNDEF(MVT::v8i1), Op2, zeroConst); + SDValue newSelect = DAG.getNode(ISD::SELECT, DL, MVT::v8i1, + Cond, Op1, Op2); + return DAG.getNode(ISD::EXTRACT_SUBVECTOR, DL, VT, newSelect, zeroConst); + } + if (Cond.getOpcode() == ISD::SETCC) { SDValue NewCond = LowerSETCC(Cond, DAG); if (NewCond.getNode()) @@ -19481,6 +19498,10 @@ X86TargetLowering::EmitInstrWithCustomInserter(MachineInstr *MI, case X86::CMOV_RFP32: case X86::CMOV_RFP64: case X86::CMOV_RFP80: + case X86::CMOV_V8I1: + case X86::CMOV_V16I1: + case X86::CMOV_V32I1: + case X86::CMOV_V64I1: return EmitLoweredSelect(MI, BB); case X86::FP32_TO_INT16_IN_MEM: diff --git a/lib/Target/X86/X86InstrCompiler.td b/lib/Target/X86/X86InstrCompiler.td index 6abb035f688..24d5d227f11 100644 --- a/lib/Target/X86/X86InstrCompiler.td +++ b/lib/Target/X86/X86InstrCompiler.td @@ -518,6 +518,10 @@ let usesCustomInserter = 1, Uses = [EFLAGS] in { defm _V8I64 : CMOVrr_PSEUDO; defm _V8F64 : CMOVrr_PSEUDO; defm _V16F32 : CMOVrr_PSEUDO; + defm _V8I1 : CMOVrr_PSEUDO; + defm _V16I1 : CMOVrr_PSEUDO; + defm _V32I1 : CMOVrr_PSEUDO; + defm _V64I1 : CMOVrr_PSEUDO; } // usesCustomInserter = 1, Uses = [EFLAGS] //===----------------------------------------------------------------------===// diff --git a/test/CodeGen/X86/avx512-mask-op.ll b/test/CodeGen/X86/avx512-mask-op.ll index 5f3588d68d0..6e12289389a 100644 --- a/test/CodeGen/X86/avx512-mask-op.ll +++ b/test/CodeGen/X86/avx512-mask-op.ll @@ -207,3 +207,78 @@ true: false: ret void } + +; KNL-LABEL: test8 +; KNL: vpxord %zmm2, %zmm2, %zmm2 +; KNL: jg +; KNL: vpcmpltud %zmm2, %zmm1, %k1 +; KNL: jmp +; KNL: vpcmpgtd %zmm2, %zmm0, %k1 + +; SKX-LABEL: test8 +; SKX: jg +; SKX: vpcmpltud {{.*}}, %k0 +; SKX: vpmovm2b +; SKX: vpcmpgtd {{.*}}, %k0 +; SKX: vpmovm2b + +define <16 x i8> @test8(<16 x i32>%a, <16 x i32>%b, i32 %a1, i32 %b1) { + %cond = icmp sgt i32 %a1, %b1 + %cmp1 = icmp sgt <16 x i32> %a, zeroinitializer + %cmp2 = icmp ult <16 x i32> %b, zeroinitializer + %mix = select i1 %cond, <16 x i1> %cmp1, <16 x i1> %cmp2 + %res = sext <16 x i1> %mix to <16 x i8> + ret <16 x i8> %res +} + +; KNL-LABEL: test9 +; KNL: jg +; KNL: vpmovsxbd %xmm1, %zmm0 +; KNL: jmp +; KNL: vpmovsxbd %xmm0, %zmm0 + +; SKX-LABEL: test9 +; SKX: vpmovb2m %xmm1, %k0 +; SKX: vpmovm2b %k0, %xmm0 +; SKX: retq +; SKX: vpmovb2m %xmm0, %k0 +; SKX: vpmovm2b %k0, %xmm0 + +define <16 x i1> @test9(<16 x i1>%a, <16 x i1>%b, i32 %a1, i32 %b1) { + %mask = icmp sgt i32 %a1, %b1 + %c = select i1 %mask, <16 x i1>%a, <16 x i1>%b + ret <16 x i1>%c +} + +; KNL-LABEL: test10 +; KNL: jg +; KNL: vpmovsxwq %xmm1, %zmm0 +; KNL: jmp +; KNL: vpmovsxwq %xmm0, %zmm0 + +; SKX-LABEL: test10 +; SKX: jg +; SKX: vpmovw2m %xmm1, %k0 +; SKX: vpmovm2w %k0, %xmm0 +; SKX: retq +; SKX: vpmovw2m %xmm0, %k0 +; SKX: vpmovm2w %k0, %xmm0 +define <8 x i1> @test10(<8 x i1>%a, <8 x i1>%b, i32 %a1, i32 %b1) { + %mask = icmp sgt i32 %a1, %b1 + %c = select i1 %mask, <8 x i1>%a, <8 x i1>%b + ret <8 x i1>%c +} + +; SKX-LABEL: test11 +; SKX: jg +; SKX: vpmovd2m %xmm1, %k0 +; SKX: vpmovm2d %k0, %xmm0 +; SKX: retq +; SKX: vpmovd2m %xmm0, %k0 +; SKX: vpmovm2d %k0, %xmm0 +define <4 x i1> @test11(<4 x i1>%a, <4 x i1>%b, i32 %a1, i32 %b1) { + %mask = icmp sgt i32 %a1, %b1 + %c = select i1 %mask, <4 x i1>%a, <4 x i1>%b + ret <4 x i1>%c +} +