From d6fb53adb19ccfbfb1eedec11c899aaa8401d036 Mon Sep 17 00:00:00 2001
From: Nadav Rotem
Date: Thu, 27 Dec 2012 08:15:45 +0000
Subject: [PATCH] On AVX/AVX2 the type v8i1 is legalized to v8i16, which is an
 XMM-sized register. In most cases we actually compare or select YMM-sized
 registers, and mixing the two types creates horrible code. This commit
 optimizes some of the transition sequences. PR14657.

git-svn-id: https://llvm.org/svn/llvm-project/llvm/trunk@171148 91177308-0d34-0410-b5e6-96231b3b80d8
---
 lib/Target/X86/X86ISelLowering.cpp | 121 +++++++++++++++++++++++++----
 test/CodeGen/X86/v8i1-masks.ll     |  38 +++++++++
 2 files changed, 144 insertions(+), 15 deletions(-)
 create mode 100644 test/CodeGen/X86/v8i1-masks.ll

diff --git a/lib/Target/X86/X86ISelLowering.cpp b/lib/Target/X86/X86ISelLowering.cpp
index 7016b4465d2..d3c21bd703f 100644
--- a/lib/Target/X86/X86ISelLowering.cpp
+++ b/lib/Target/X86/X86ISelLowering.cpp
@@ -15731,9 +15731,92 @@ static bool CanFoldXORWithAllOnes(const SDNode *N) {
   return false;
 }
 
+// On AVX/AVX2 the type v8i1 is legalized to v8i16, which is an XMM-sized
+// register. In most cases we actually compare or select YMM-sized registers
+// and mixing the two types creates horrible code. This method optimizes
+// some of the transition sequences.
+static SDValue WidenMaskArithmetic(SDNode *N, SelectionDAG &DAG,
+                                   TargetLowering::DAGCombinerInfo &DCI,
+                                   const X86Subtarget *Subtarget) {
+  EVT VT = N->getValueType(0);
+  if (VT.getSizeInBits() != 256)
+    return SDValue();
+
+  assert((N->getOpcode() == ISD::ANY_EXTEND ||
+          N->getOpcode() == ISD::ZERO_EXTEND ||
+          N->getOpcode() == ISD::SIGN_EXTEND) && "Invalid Node");
+
+  SDValue Narrow = N->getOperand(0);
+  EVT NarrowVT = Narrow->getValueType(0);
+  if (NarrowVT.getSizeInBits() != 128)
+    return SDValue();
+
+  if (Narrow->getOpcode() != ISD::XOR &&
+      Narrow->getOpcode() != ISD::AND &&
+      Narrow->getOpcode() != ISD::OR)
+    return SDValue();
+
+  SDValue N0  = Narrow->getOperand(0);
+  SDValue N1  = Narrow->getOperand(1);
+  DebugLoc DL = Narrow->getDebugLoc();
+
+  // The left side has to be a trunc.
+  if (N0.getOpcode() != ISD::TRUNCATE)
+    return SDValue();
+
+  // The type of the truncated inputs.
+  EVT WideVT = N0->getOperand(0)->getValueType(0);
+  if (WideVT != VT)
+    return SDValue();
+
+  // The right side has to be a 'trunc' or a constant vector.
+  bool RHSTrunc = N1.getOpcode() == ISD::TRUNCATE;
+  bool RHSConst = (isSplatVector(N1.getNode()) &&
+                   isa<ConstantSDNode>(N1->getOperand(0)));
+  if (!RHSTrunc && !RHSConst)
+    return SDValue();
+
+  const TargetLowering &TLI = DAG.getTargetLoweringInfo();
+
+  if (!TLI.isOperationLegalOrPromote(Narrow->getOpcode(), WideVT))
+    return SDValue();
+
+  // Set N0 and N1 to hold the inputs to the new wide operation.
+  N0 = N0->getOperand(0);
+  if (RHSConst) {
+    N1 = DAG.getNode(ISD::ZERO_EXTEND, DL, WideVT.getScalarType(),
+                     N1->getOperand(0));
+    SmallVector<SDValue, 8> C(WideVT.getVectorNumElements(), N1);
+    N1 = DAG.getNode(ISD::BUILD_VECTOR, DL, WideVT, &C[0], C.size());
+  } else if (RHSTrunc) {
+    N1 = N1->getOperand(0);
+  }
+
+  // Generate the wide operation.
+  SDValue Op = DAG.getNode(N->getOpcode(), DL, WideVT, N0, N1);
+  unsigned Opcode = N->getOpcode();
+  switch (Opcode) {
+  case ISD::ANY_EXTEND:
+    return Op;
+  case ISD::ZERO_EXTEND: {
+    unsigned InBits = NarrowVT.getScalarType().getSizeInBits();
+    APInt Mask = APInt::getAllOnesValue(InBits);
+    Mask = Mask.zext(VT.getScalarType().getSizeInBits());
+    return DAG.getNode(ISD::AND, DL, VT,
+                       Op, DAG.getConstant(Mask, VT));
+  }
+  case ISD::SIGN_EXTEND:
+    return DAG.getNode(ISD::SIGN_EXTEND_INREG, DL, VT,
+                       Op, DAG.getValueType(NarrowVT));
+  default:
+    llvm_unreachable("Unexpected opcode");
+  }
+}
+
 static SDValue PerformAndCombine(SDNode *N, SelectionDAG &DAG,
                                  TargetLowering::DAGCombinerInfo &DCI,
                                  const X86Subtarget *Subtarget) {
+  EVT VT = N->getValueType(0);
   if (DCI.isBeforeLegalizeOps())
     return SDValue();
 
@@ -15741,8 +15824,6 @@ static SDValue PerformAndCombine(SDNode *N, SelectionDAG &DAG,
   if (R.getNode())
     return R;
 
-  EVT VT = N->getValueType(0);
-
   // Create BLSI, and BLSR instructions
   // BLSI is X & (-X)
   // BLSR is X & (X-1)
@@ -15803,6 +15884,7 @@ static SDValue PerformAndCombine(SDNode *N, SelectionDAG &DAG,
 static SDValue PerformOrCombine(SDNode *N, SelectionDAG &DAG,
                                 TargetLowering::DAGCombinerInfo &DCI,
                                 const X86Subtarget *Subtarget) {
+  EVT VT = N->getValueType(0);
   if (DCI.isBeforeLegalizeOps())
     return SDValue();
 
@@ -15810,8 +15892,6 @@ static SDValue PerformOrCombine(SDNode *N, SelectionDAG &DAG,
   if (R.getNode())
     return R;
 
-  EVT VT = N->getValueType(0);
-
   SDValue N0 = N->getOperand(0);
   SDValue N1 = N->getOperand(1);
 
@@ -15991,6 +16071,7 @@ static SDValue performIntegerAbsCombine(SDNode *N, SelectionDAG &DAG) {
 static SDValue PerformXorCombine(SDNode *N, SelectionDAG &DAG,
                                  TargetLowering::DAGCombinerInfo &DCI,
                                  const X86Subtarget *Subtarget) {
+  EVT VT = N->getValueType(0);
   if (DCI.isBeforeLegalizeOps())
     return SDValue();
 
@@ -16004,8 +16085,6 @@ static SDValue PerformXorCombine(SDNode *N, SelectionDAG &DAG,
   if (!Subtarget->hasBMI())
     return SDValue();
 
-  EVT VT = N->getValueType(0);
-
   if (VT != MVT::i32 && VT != MVT::i64)
     return SDValue();
 
@@ -16671,6 +16750,12 @@ static SDValue PerformSExtCombine(SDNode *N, SelectionDAG &DAG,
   EVT OpVT = Op.getValueType();
   DebugLoc dl = N->getDebugLoc();
 
+  if (VT.isVector() && VT.getSizeInBits() == 256) {
+    SDValue R = WidenMaskArithmetic(N, DAG, DCI, Subtarget);
+    if (R.getNode())
+      return R;
+  }
+
   if ((VT == MVT::v4i64 && OpVT == MVT::v4i32) ||
       (VT == MVT::v8i32 && OpVT == MVT::v8i16)) {
 
@@ -16768,15 +16853,21 @@ static SDValue PerformZExtCombine(SDNode *N, SelectionDAG &DAG,
       N0.hasOneUse() &&
       N0.getOperand(0).hasOneUse()) {
     SDValue N00 = N0.getOperand(0);
-    if (N00.getOpcode() != X86ISD::SETCC_CARRY)
-      return SDValue();
-    ConstantSDNode *C = dyn_cast<ConstantSDNode>(N0.getOperand(1));
-    if (!C || C->getZExtValue() != 1)
-      return SDValue();
-    return DAG.getNode(ISD::AND, dl, VT,
-                       DAG.getNode(X86ISD::SETCC_CARRY, dl, VT,
-                                   N00.getOperand(0), N00.getOperand(1)),
-                       DAG.getConstant(1, VT));
+    if (N00.getOpcode() == X86ISD::SETCC_CARRY) {
+      ConstantSDNode *C = dyn_cast<ConstantSDNode>(N0.getOperand(1));
+      if (!C || C->getZExtValue() != 1)
+        return SDValue();
+      return DAG.getNode(ISD::AND, dl, VT,
+                         DAG.getNode(X86ISD::SETCC_CARRY, dl, VT,
+                                     N00.getOperand(0), N00.getOperand(1)),
+                         DAG.getConstant(1, VT));
+    }
+  }
+
+  if (VT.isVector() && VT.getSizeInBits() == 256) {
+    SDValue R = WidenMaskArithmetic(N, DAG, DCI, Subtarget);
+    if (R.getNode())
+      return R;
   }
 
   // Optimize vectors in AVX mode:
diff --git a/test/CodeGen/X86/v8i1-masks.ll b/test/CodeGen/X86/v8i1-masks.ll
new file mode 100644
index 00000000000..01079997a33
--- /dev/null
+++ b/test/CodeGen/X86/v8i1-masks.ll
@@ -0,0 +1,38 @@
+; RUN: llc -march=x86-64 -mtriple=x86_64-apple-darwin -mcpu=corei7-avx -o - < %s | FileCheck %s
+
+;CHECK: and_masks
+;CHECK: vmovups
+;CHECK-NEXT: vcmpltp
+;CHECK-NEXT: vandps
+;CHECK-NEXT: vmovups
+;CHECK: ret
+
+define void @and_masks(<8 x float>* %a, <8 x float>* %b, <8 x float>* %c) nounwind uwtable noinline ssp {
+  %v0 = load <8 x float>* %a, align 16
+  %v1 = load <8 x float>* %b, align 16
+  %m0 = fcmp olt <8 x float> %v1, %v0
+  %v2 = load <8 x float>* %c, align 16
+  %m1 = fcmp olt <8 x float> %v2, %v0
+  %mand = and <8 x i1> %m1, %m0
+  %r = zext <8 x i1> %mand to <8 x i32>
+  store <8 x i32> %r, <8 x i32>* undef, align 16
+  ret void
+}
+
+;CHECK: neg_mask
+;CHECK: vmovups
+;CHECK-NEXT: vcmpltps
+;CHECK-NEXT: vandps
+;CHECK-NEXT: vmovups
+;CHECK: ret
+
+define void @neg_masks(<8 x float>* %a, <8 x float>* %b, <8 x float>* %c) nounwind uwtable noinline ssp {
+  %v0 = load <8 x float>* %a, align 16
+  %v1 = load <8 x float>* %b, align 16
+  %m0 = fcmp olt <8 x float> %v1, %v0
+  %mand = xor <8 x i1> %m0, <i1 1, i1 1, i1 1, i1 1, i1 1, i1 1, i1 1, i1 1>
+  %r = zext <8 x i1> %mand to <8 x i32>
+  store <8 x i32> %r, <8 x i32>* undef, align 16
+  ret void
+}
+
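Note (illustrative, not part of the patch): WidenMaskArithmetic matches ISD::OR as well as
ISD::AND and ISD::XOR, so a mask-or in the same style as the tests above would be widened
to a single YMM operation in the same way. The @or_masks function below is a hypothetical
sketch of that kind of input, not a test added by this commit, and no CHECK lines are
claimed for it.

define void @or_masks(<8 x float>* %a, <8 x float>* %b, <8 x float>* %c) nounwind uwtable noinline ssp {
  %v0 = load <8 x float>* %a, align 16
  %v1 = load <8 x float>* %b, align 16
  %m0 = fcmp olt <8 x float> %v1, %v0
  %v2 = load <8 x float>* %c, align 16
  %m1 = fcmp olt <8 x float> %v2, %v0
  %mor = or <8 x i1> %m1, %m0
  %r = zext <8 x i1> %mor to <8 x i32>
  store <8 x i32> %r, <8 x i32>* undef, align 16
  ret void
}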