diff --git a/lib/CodeGen/SelectionDAG/LegalizeTypes.h b/lib/CodeGen/SelectionDAG/LegalizeTypes.h index 5085380cbb6..05900d006f0 100644 --- a/lib/CodeGen/SelectionDAG/LegalizeTypes.h +++ b/lib/CodeGen/SelectionDAG/LegalizeTypes.h @@ -561,6 +561,7 @@ private: void SplitVecRes_BinOp(SDNode *N, SDValue &Lo, SDValue &Hi); void SplitVecRes_TernaryOp(SDNode *N, SDValue &Lo, SDValue &Hi); void SplitVecRes_UnaryOp(SDNode *N, SDValue &Lo, SDValue &Hi); + void SplitVecRes_ExtendOp(SDNode *N, SDValue &Lo, SDValue &Hi); void SplitVecRes_InregOp(SDNode *N, SDValue &Lo, SDValue &Hi); void SplitVecRes_BITCAST(SDNode *N, SDValue &Lo, SDValue &Hi); diff --git a/lib/CodeGen/SelectionDAG/LegalizeVectorTypes.cpp b/lib/CodeGen/SelectionDAG/LegalizeVectorTypes.cpp index ad31f7e94fd..a4e2fc472ed 100644 --- a/lib/CodeGen/SelectionDAG/LegalizeVectorTypes.cpp +++ b/lib/CodeGen/SelectionDAG/LegalizeVectorTypes.cpp @@ -521,7 +521,6 @@ void DAGTypeLegalizer::SplitVectorResult(SDNode *N, unsigned ResNo) { SplitVecRes_VECTOR_SHUFFLE(cast(N), Lo, Hi); break; - case ISD::ANY_EXTEND: case ISD::CONVERT_RNDSAT: case ISD::CTLZ: case ISD::CTTZ: @@ -548,14 +547,18 @@ void DAGTypeLegalizer::SplitVectorResult(SDNode *N, unsigned ResNo) { case ISD::FSIN: case ISD::FSQRT: case ISD::FTRUNC: - case ISD::SIGN_EXTEND: case ISD::SINT_TO_FP: case ISD::TRUNCATE: case ISD::UINT_TO_FP: - case ISD::ZERO_EXTEND: SplitVecRes_UnaryOp(N, Lo, Hi); break; + case ISD::ANY_EXTEND: + case ISD::SIGN_EXTEND: + case ISD::ZERO_EXTEND: + SplitVecRes_ExtendOp(N, Lo, Hi); + break; + case ISD::ADD: case ISD::SUB: case ISD::MUL: @@ -921,6 +924,62 @@ void DAGTypeLegalizer::SplitVecRes_UnaryOp(SDNode *N, SDValue &Lo, } } +void DAGTypeLegalizer::SplitVecRes_ExtendOp(SDNode *N, SDValue &Lo, + SDValue &Hi) { + SDLoc dl(N); + EVT SrcVT = N->getOperand(0).getValueType(); + EVT DestVT = N->getValueType(0); + EVT LoVT, HiVT; + GetSplitDestVTs(DestVT, LoVT, HiVT); + + // We can do better than a generic split operation if the extend is doing + // more than just doubling the width of the elements and the following are + // true: + // - The number of vector elements is even, + // - the source type is legal, + // - the type of a split source is illegal, + // - the type of an extended (by doubling element size) source is legal, and + // - the type of that extended source when split is legal. + // + // This won't necessarily completely legalize the operation, but it will + // more effectively move in the right direction and prevent falling down + // to scalarization in many cases due to the input vector being split too + // far. + unsigned NumElements = SrcVT.getVectorNumElements(); + if ((NumElements & 1) == 0 && + SrcVT.getSizeInBits() * 2 < DestVT.getSizeInBits()) { + LLVMContext &Ctx = *DAG.getContext(); + EVT NewSrcVT = EVT::getVectorVT( + Ctx, EVT::getIntegerVT( + Ctx, SrcVT.getVectorElementType().getSizeInBits() * 2), + NumElements); + EVT SplitSrcVT = + EVT::getVectorVT(Ctx, SrcVT.getVectorElementType(), NumElements / 2); + EVT SplitLoVT, SplitHiVT; + GetSplitDestVTs(NewSrcVT, SplitLoVT, SplitHiVT); + if (TLI.isTypeLegal(SrcVT) && !TLI.isTypeLegal(SplitSrcVT) && + TLI.isTypeLegal(NewSrcVT) && TLI.isTypeLegal(SplitLoVT)) { + DEBUG(dbgs() << "Split vector extend via incremental extend:"; + N->dump(&DAG); dbgs() << "\n"); + // Extend the source vector by one step. + SDValue NewSrc = + DAG.getNode(N->getOpcode(), dl, NewSrcVT, N->getOperand(0)); + // Get the low and high halves of the new, extended one step, vector. + Lo = DAG.getNode(ISD::EXTRACT_SUBVECTOR, dl, SplitLoVT, NewSrc, + DAG.getConstant(0, TLI.getVectorIdxTy())); + Hi = DAG.getNode(ISD::EXTRACT_SUBVECTOR, dl, SplitHiVT, NewSrc, + DAG.getConstant(SplitLoVT.getVectorNumElements(), + TLI.getVectorIdxTy())); + // Extend those vector halves the rest of the way. + Lo = DAG.getNode(N->getOpcode(), dl, LoVT, Lo); + Hi = DAG.getNode(N->getOpcode(), dl, HiVT, Hi); + return; + } + } + // Fall back to the generic unary operator splitting otherwise. + SplitVecRes_UnaryOp(N, Lo, Hi); +} + void DAGTypeLegalizer::SplitVecRes_VECTOR_SHUFFLE(ShuffleVectorSDNode *N, SDValue &Lo, SDValue &Hi) { // The low and high parts of the original input give four input vectors. diff --git a/lib/Target/ARM/ARMISelLowering.cpp b/lib/Target/ARM/ARMISelLowering.cpp index 5296b3b848d..2de2dfa0ca7 100644 --- a/lib/Target/ARM/ARMISelLowering.cpp +++ b/lib/Target/ARM/ARMISelLowering.cpp @@ -567,16 +567,6 @@ ARMTargetLowering::ARMTargetLowering(TargetMachine &TM) setOperationAction(ISD::FP_ROUND, MVT::v2f32, Expand); setOperationAction(ISD::FP_EXTEND, MVT::v2f64, Expand); - // Custom expand long extensions to vectors. - setOperationAction(ISD::SIGN_EXTEND, MVT::v8i32, Custom); - setOperationAction(ISD::ZERO_EXTEND, MVT::v8i32, Custom); - setOperationAction(ISD::SIGN_EXTEND, MVT::v4i64, Custom); - setOperationAction(ISD::ZERO_EXTEND, MVT::v4i64, Custom); - setOperationAction(ISD::SIGN_EXTEND, MVT::v16i32, Custom); - setOperationAction(ISD::ZERO_EXTEND, MVT::v16i32, Custom); - setOperationAction(ISD::SIGN_EXTEND, MVT::v8i64, Custom); - setOperationAction(ISD::ZERO_EXTEND, MVT::v8i64, Custom); - // NEON does not have single instruction CTPOP for vectors with element // types wider than 8-bits. However, custom lowering can leverage the // v8i8/v16i8 vcnt instruction. @@ -3830,47 +3820,6 @@ SDValue ARMTargetLowering::LowerFRAMEADDR(SDValue Op, SelectionDAG &DAG) const { return FrameAddr; } -/// Custom Expand long vector extensions, where size(DestVec) > 2*size(SrcVec), -/// and size(DestVec) > 128-bits. -/// This is achieved by doing the one extension from the SrcVec, splitting the -/// result, extending these parts, and then concatenating these into the -/// destination. -static SDValue ExpandVectorExtension(SDNode *N, SelectionDAG &DAG) { - SDValue Op = N->getOperand(0); - EVT SrcVT = Op.getValueType(); - EVT DestVT = N->getValueType(0); - - assert(DestVT.getSizeInBits() > 128 && - "Custom sext/zext expansion needs >128-bit vector."); - // If this is a normal length extension, use the default expansion. - if (SrcVT.getSizeInBits()*4 != DestVT.getSizeInBits() && - SrcVT.getSizeInBits()*8 != DestVT.getSizeInBits()) - return SDValue(); - - SDLoc dl(N); - unsigned SrcEltSize = SrcVT.getVectorElementType().getSizeInBits(); - unsigned DestEltSize = DestVT.getVectorElementType().getSizeInBits(); - unsigned NumElts = SrcVT.getVectorNumElements(); - LLVMContext &Ctx = *DAG.getContext(); - SDValue Mid, SplitLo, SplitHi, ExtLo, ExtHi; - - EVT MidVT = EVT::getVectorVT(Ctx, EVT::getIntegerVT(Ctx, SrcEltSize*2), - NumElts); - EVT SplitVT = EVT::getVectorVT(Ctx, EVT::getIntegerVT(Ctx, SrcEltSize*2), - NumElts/2); - EVT ExtVT = EVT::getVectorVT(Ctx, EVT::getIntegerVT(Ctx, DestEltSize), - NumElts/2); - - Mid = DAG.getNode(N->getOpcode(), dl, MidVT, Op); - SplitLo = DAG.getNode(ISD::EXTRACT_SUBVECTOR, dl, SplitVT, Mid, - DAG.getIntPtrConstant(0)); - SplitHi = DAG.getNode(ISD::EXTRACT_SUBVECTOR, dl, SplitVT, Mid, - DAG.getIntPtrConstant(NumElts/2)); - ExtLo = DAG.getNode(N->getOpcode(), dl, ExtVT, SplitLo); - ExtHi = DAG.getNode(N->getOpcode(), dl, ExtVT, SplitHi); - return DAG.getNode(ISD::CONCAT_VECTORS, dl, DestVT, ExtLo, ExtHi); -} - /// ExpandBITCAST - If the target supports VFP, this function is called to /// expand a bit convert where either the source or destination type is i64 to /// use a VMOVDRR or VMOVRRD node. This should not be done when the non-i64 @@ -6149,10 +6098,6 @@ void ARMTargetLowering::ReplaceNodeResults(SDNode *N, case ISD::BITCAST: Res = ExpandBITCAST(N, DAG); break; - case ISD::SIGN_EXTEND: - case ISD::ZERO_EXTEND: - Res = ExpandVectorExtension(N, DAG); - break; case ISD::SRL: case ISD::SRA: Res = Expand64BitShift(N, DAG, Subtarget); diff --git a/test/CodeGen/X86/long-extend.ll b/test/CodeGen/X86/long-extend.ll new file mode 100644 index 00000000000..5bbd41dad9d --- /dev/null +++ b/test/CodeGen/X86/long-extend.ll @@ -0,0 +1,18 @@ +; RUN: llc < %s -mcpu=core-avx-i -mtriple=x86_64-linux -asm-verbose=0| FileCheck %s +define void @test_long_extend(<16 x i8> %a, <16 x i32>* %p) nounwind { +; CHECK-LABEL: test_long_extend +; CHECK: vpunpcklbw %xmm1, %xmm0, [[REG1:%xmm[0-9]+]] +; CHECK: vpunpckhwd %xmm1, [[REG1]], [[REG2:%xmm[0-9]+]] +; CHECK: vpunpcklwd %xmm1, [[REG1]], %x[[REG3:mm[0-9]+]] +; CHECK: vinsertf128 $1, [[REG2]], %y[[REG3]], [[REG_result0:%ymm[0-9]+]] +; CHECK: vpunpckhbw %xmm1, %xmm0, [[REG4:%xmm[0-9]+]] +; CHECK: vpunpckhwd %xmm1, [[REG4]], [[REG5:%xmm[0-9]+]] +; CHECK: vpunpcklwd %xmm1, [[REG4]], %x[[REG6:mm[0-9]+]] +; CHECK: vinsertf128 $1, [[REG5]], %y[[REG6]], [[REG_result1:%ymm[0-9]+]] +; CHECK: vmovaps [[REG_result1]], 32(%rdi) +; CHECK: vmovaps [[REG_result0]], (%rdi) + + %tmp = zext <16 x i8> %a to <16 x i32> + store <16 x i32> %tmp, <16 x i32>*%p + ret void +}