diff --git a/lib/CodeGen/SelectionDAG/TargetLowering.cpp b/lib/CodeGen/SelectionDAG/TargetLowering.cpp
index bd930d8ef19..5574b214410 100644
--- a/lib/CodeGen/SelectionDAG/TargetLowering.cpp
+++ b/lib/CodeGen/SelectionDAG/TargetLowering.cpp
@@ -1608,23 +1608,37 @@ bool TargetLowering::SimplifyDemandedBits(SDValue Op,
     }
     break;
   case ISD::SIGN_EXTEND_INREG: {
-    EVT EVT = cast<VTSDNode>(Op.getOperand(1))->getVT();
+    EVT ExVT = cast<VTSDNode>(Op.getOperand(1))->getVT();
+
+    APInt MsbMask = APInt::getHighBitsSet(BitWidth, 1);
+    // If we only care about the highest bit, don't bother shifting right.
+    if (MsbMask == DemandedMask) {
+      unsigned ShAmt = ExVT.getScalarType().getSizeInBits();
+      SDValue InOp = Op.getOperand(0);
+      // This code may handle vector types, and getShiftAmountTy only
+      // works on scalars, so use the result value type for the shift
+      // amount; we know it's an integer with enough bits.
+      SDValue ShiftAmt = TLO.DAG.getConstant(BitWidth - ShAmt,
+                                             Op.getValueType());
+      return TLO.CombineTo(Op, TLO.DAG.getNode(ISD::SHL, dl,
+                                               Op.getValueType(), InOp, ShiftAmt));
+    }
 
     // Sign extension.  Compute the demanded bits in the result that are not
     // present in the input.
     APInt NewBits =
       APInt::getHighBitsSet(BitWidth,
-                            BitWidth - EVT.getScalarType().getSizeInBits());
+                            BitWidth - ExVT.getScalarType().getSizeInBits());
 
     // If none of the extended bits are demanded, eliminate the sextinreg.
     if ((NewBits & NewMask) == 0)
       return TLO.CombineTo(Op, Op.getOperand(0));
 
     APInt InSignBit =
-      APInt::getSignBit(EVT.getScalarType().getSizeInBits()).zext(BitWidth);
+      APInt::getSignBit(ExVT.getScalarType().getSizeInBits()).zext(BitWidth);
     APInt InputDemandedBits =
       APInt::getLowBitsSet(BitWidth,
-                           EVT.getScalarType().getSizeInBits()) &
+                           ExVT.getScalarType().getSizeInBits()) &
       NewMask;
 
     // Since the sign extended bits are demanded, we know that the sign
@@ -1642,7 +1656,7 @@ bool TargetLowering::SimplifyDemandedBits(SDValue Op,
     // If the input sign bit is known zero, convert this into a zero extension.
     if (KnownZero.intersects(InSignBit))
       return TLO.CombineTo(Op,
-                           TLO.DAG.getZeroExtendInReg(Op.getOperand(0),dl,EVT));
+                          TLO.DAG.getZeroExtendInReg(Op.getOperand(0),dl,ExVT));
 
     if (KnownOne.intersects(InSignBit)) {    // Input sign bit known set
       KnownOne |= NewBits;
diff --git a/lib/Target/X86/X86ISelLowering.cpp b/lib/Target/X86/X86ISelLowering.cpp
index 68cd44116d4..b5198c8510f 100644
--- a/lib/Target/X86/X86ISelLowering.cpp
+++ b/lib/Target/X86/X86ISelLowering.cpp
@@ -12868,6 +12868,7 @@ static SDValue PerformEXTRACT_VECTOR_ELTCombine(SDNode *N, SelectionDAG &DAG,
 /// PerformSELECTCombine - Do target-specific dag combines on SELECT and VSELECT
 /// nodes.
 static SDValue PerformSELECTCombine(SDNode *N, SelectionDAG &DAG,
+                                    TargetLowering::DAGCombinerInfo &DCI,
                                     const X86Subtarget *Subtarget) {
   DebugLoc DL = N->getDebugLoc();
   SDValue Cond = N->getOperand(0);
@@ -13144,6 +13145,26 @@ static SDValue PerformSELECTCombine(SDNode *N, SelectionDAG &DAG,
     }
   }
 
+  // If we know that this node is legal then we know that it is going to be
+  // matched by one of the SSE/AVX BLEND instructions. These instructions
+  // only depend on the highest bit in each element. Try to use
+  // SimplifyDemandedBits to simplify the instructions that compute the mask.
+  const TargetLowering &TLI = DAG.getTargetLoweringInfo();
+  if (N->getOpcode() == ISD::VSELECT && DCI.isBeforeLegalizeOps() &&
+      !DCI.isBeforeLegalize() &&
+      TLI.isOperationLegal(ISD::VSELECT, VT)) {
+    unsigned BitWidth = Cond.getValueType().getScalarType().getSizeInBits();
+    assert(BitWidth >= 8 && BitWidth <= 64 && "Invalid mask size");
+    APInt DemandedMask = APInt::getHighBitsSet(BitWidth, 1);
+
+    APInt KnownZero, KnownOne;
+    TargetLowering::TargetLoweringOpt TLO(DAG, DCI.isBeforeLegalize(),
+                                          DCI.isBeforeLegalizeOps());
+    if (TLO.ShrinkDemandedConstant(Cond, DemandedMask) ||
+        TLI.SimplifyDemandedBits(Cond, DemandedMask, KnownZero, KnownOne, TLO))
+      DCI.CommitTargetLoweringOpt(TLO);
+  }
+
   return SDValue();
 }
 
@@ -14609,7 +14630,7 @@ SDValue X86TargetLowering::PerformDAGCombine(SDNode *N,
   case ISD::EXTRACT_VECTOR_ELT:
     return PerformEXTRACT_VECTOR_ELTCombine(N, DAG, *this);
   case ISD::VSELECT:
-  case ISD::SELECT:         return PerformSELECTCombine(N, DAG, Subtarget);
+  case ISD::SELECT:         return PerformSELECTCombine(N, DAG, DCI, Subtarget);
   case X86ISD::CMOV:        return PerformCMOVCombine(N, DAG, DCI);
   case ISD::ADD:            return PerformAddCombine(N, DAG, Subtarget);
   case ISD::SUB:            return PerformSubCombine(N, DAG, Subtarget);
diff --git a/test/CodeGen/X86/blend-msb.ll b/test/CodeGen/X86/blend-msb.ll
new file mode 100644
index 00000000000..3a10c70ada8
--- /dev/null
+++ b/test/CodeGen/X86/blend-msb.ll
@@ -0,0 +1,37 @@
+; RUN: llc < %s -mtriple=x86_64-apple-darwin -mcpu=corei7 -promote-elements -mattr=+sse41 | FileCheck %s
+
+
+; In this test we check that the sign-extend of the mask bit is performed by
+; shifting the needed bit to the MSB, rather than with shl+sra.
+
+;CHECK: vsel_float
+;CHECK: pslld
+;CHECK-NEXT: blendvps
+;CHECK: ret
+define <4 x float> @vsel_float(<4 x float> %v1, <4 x float> %v2) {
+  %vsel = select <4 x i1> <i1 true, i1 false, i1 false, i1 false>, <4 x float> %v1, <4 x float> %v2
+  ret <4 x float> %vsel
+}
+
+;CHECK: vsel_4xi8
+;CHECK: pslld
+;CHECK-NEXT: blendvps
+;CHECK: ret
+define <4 x i8> @vsel_4xi8(<4 x i8> %v1, <4 x i8> %v2) {
+  %vsel = select <4 x i1> <i1 true, i1 false, i1 false, i1 false>, <4 x i8> %v1, <4 x i8> %v2
+  ret <4 x i8> %vsel
+}
+
+
+; We do not have native support for v8i16 blends and we have to use the
+; blendvb instruction or a sequence of NAND/OR/AND. Make sure that we do not
+; reduce the mask in this case.
+;CHECK: vsel_8xi16
+;CHECK: psllw
+;CHECK: psraw
+;CHECK: pblendvb
+;CHECK: ret
+define <8 x i16> @vsel_8xi16(<8 x i16> %v1, <8 x i16> %v2) {
+  %vsel = select <8 x i1> <i1 true, i1 false, i1 false, i1 false, i1 true, i1 false, i1 false, i1 false>, <8 x i16> %v1, <8 x i16> %v2
+  ret <8 x i16> %vsel
+}
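
---
A quick standalone check, in plain C++ and independent of LLVM, of the
invariant the SIGN_EXTEND_INREG shortcut above relies on: when only the most
significant bit of the result is demanded, a single left shift that moves the
narrow value's sign bit into the MSB yields the same MSB as the full shl+sra
sign-extend-in-reg sequence. The helper name below is made up for
illustration.

#include <cassert>
#include <cstdint>
#include <initializer_list>

// Sign-extend the low ExBits bits of X inside a 32-bit word, the way an
// SHL followed by an arithmetic SRA would.
static uint32_t signExtendInReg(uint32_t X, unsigned ExBits) {
  unsigned Sh = 32 - ExBits;
  return (uint32_t)((int32_t)(X << Sh) >> Sh);
}

int main() {
  for (unsigned ExBits = 1; ExBits < 32; ++ExBits) {
    for (uint32_t X : {0u, 1u, 0x5Au, 0x80000000u, 0xFFFFFFFFu}) {
      uint32_t Full = signExtendInReg(X, ExBits); // shl + sra
      uint32_t MsbOnly = X << (32 - ExBits);      // shl only
      // Only the MSB (the bit the BLEND instructions read) must match.
      assert((Full & 0x80000000u) == (MsbOnly & 0x80000000u));
    }
  }
  return 0;
}

Both sequences place bit (ExBits - 1) of the input into the MSB, which is why
the DemandedMask == MsbMask case can drop the SRA entirely.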