diff --git a/lib/Target/ARM/ARMISelLowering.cpp b/lib/Target/ARM/ARMISelLowering.cpp index 008edbe4a75..432e3eefb1e 100644 --- a/lib/Target/ARM/ARMISelLowering.cpp +++ b/lib/Target/ARM/ARMISelLowering.cpp @@ -5257,6 +5257,23 @@ static bool isZeroExtended(SDNode *N, SelectionDAG &DAG) { return false; } +static EVT getExtensionTo64Bits(const EVT &OrigVT) { + if (OrigVT.getSizeInBits() >= 64) + return OrigVT; + + assert(OrigVT.isSimple() && "Expecting a simple value type"); + + MVT::SimpleValueType OrigSimpleTy = OrigVT.getSimpleVT().SimpleTy; + switch (OrigSimpleTy) { + default: llvm_unreachable("Unexpected Vector Type"); + case MVT::v2i8: + case MVT::v2i16: + return MVT::v2i32; + case MVT::v4i8: + return MVT::v4i16; + } +} + /// AddRequiredExtensionForVMULL - Add a sign/zero extension to extend the total /// value size to 64 bits. We need a 64-bit D register as an operand to VMULL. /// We insert the required extension here to get the vector to fill a D register. @@ -5272,18 +5289,8 @@ static SDValue AddRequiredExtensionForVMULL(SDValue N, SelectionDAG &DAG, return N; // Must extend size to at least 64 bits to be used as an operand for VMULL. - MVT::SimpleValueType OrigSimpleTy = OrigTy.getSimpleVT().SimpleTy; - EVT NewVT; - switch (OrigSimpleTy) { - default: llvm_unreachable("Unexpected Orig Vector Type"); - case MVT::v2i8: - case MVT::v2i16: - NewVT = MVT::v2i32; - break; - case MVT::v4i8: - NewVT = MVT::v4i16; - break; - } + EVT NewVT = getExtensionTo64Bits(OrigTy); + return DAG.getNode(ExtOpcode, N->getDebugLoc(), NewVT, N); } @@ -5293,22 +5300,22 @@ static SDValue AddRequiredExtensionForVMULL(SDValue N, SelectionDAG &DAG, /// reach a total size of 64 bits. We have to add the extension separately /// because ARM does not have a sign/zero extending load for vectors. static SDValue SkipLoadExtensionForVMULL(LoadSDNode *LD, SelectionDAG& DAG) { - SDValue NonExtendingLoad = - DAG.getLoad(LD->getMemoryVT(), LD->getDebugLoc(), LD->getChain(), + EVT ExtendedTy = getExtensionTo64Bits(LD->getMemoryVT()); + + // The load already has the right type. + if (ExtendedTy == LD->getMemoryVT()) + return DAG.getLoad(LD->getMemoryVT(), LD->getDebugLoc(), LD->getChain(), LD->getBasePtr(), LD->getPointerInfo(), LD->isVolatile(), LD->isNonTemporal(), LD->isInvariant(), LD->getAlignment()); - unsigned ExtOp = 0; - switch (LD->getExtensionType()) { - default: llvm_unreachable("Unexpected LoadExtType"); - case ISD::EXTLOAD: - case ISD::SEXTLOAD: ExtOp = ISD::SIGN_EXTEND; break; - case ISD::ZEXTLOAD: ExtOp = ISD::ZERO_EXTEND; break; - } - MVT::SimpleValueType MemType = LD->getMemoryVT().getSimpleVT().SimpleTy; - MVT::SimpleValueType ExtType = LD->getValueType(0).getSimpleVT().SimpleTy; - return AddRequiredExtensionForVMULL(NonExtendingLoad, DAG, - MemType, ExtType, ExtOp); + + // We need to create a zextload/sextload. We cannot just create a load + // followed by a zext/zext node because LowerMUL is also run during normal + // operation legalization where we can't create illegal types. + return DAG.getExtLoad(LD->getExtensionType(), LD->getDebugLoc(), ExtendedTy, + LD->getChain(), LD->getBasePtr(), LD->getPointerInfo(), + LD->getMemoryVT(), LD->isVolatile(), + LD->isNonTemporal(), LD->getAlignment()); } /// SkipExtensionForVMULL - For a node that is a SIGN_EXTEND, ZERO_EXTEND, diff --git a/test/CodeGen/ARM/vmul.ll b/test/CodeGen/ARM/vmul.ll index 74628f0c5ce..eb5ad8f0c3d 100644 --- a/test/CodeGen/ARM/vmul.ll +++ b/test/CodeGen/ARM/vmul.ll @@ -599,3 +599,27 @@ for.end179: ; preds = %for.cond.loopexit, declare <8 x i16> @llvm.arm.neon.vrshiftu.v8i16(<8 x i16>, <8 x i16>) nounwind readnone declare <8 x i16> @llvm.arm.neon.vqsubu.v8i16(<8 x i16>, <8 x i16>) nounwind readnone declare <8 x i8> @llvm.arm.neon.vqmovnu.v8i8(<8 x i16>) nounwind readnone + +; vmull lowering would create a zext(v4i8 load()) instead of a zextload(v4i8), +; creating an illegal type during legalization and causing an assert. +; PR15970 +define void @no_illegal_types_vmull_sext(<4 x i32> %a) { +entry: + %wide.load283.i = load <4 x i8>* undef, align 1 + %0 = sext <4 x i8> %wide.load283.i to <4 x i32> + %1 = sub nsw <4 x i32> %0, %a + %2 = mul nsw <4 x i32> %1, %1 + %predphi290.v.i = select <4 x i1> undef, <4 x i32> undef, <4 x i32> %2 + store <4 x i32> %predphi290.v.i, <4 x i32>* undef, align 4 + ret void +} +define void @no_illegal_types_vmull_zext(<4 x i32> %a) { +entry: + %wide.load283.i = load <4 x i8>* undef, align 1 + %0 = zext <4 x i8> %wide.load283.i to <4 x i32> + %1 = sub nsw <4 x i32> %0, %a + %2 = mul nsw <4 x i32> %1, %1 + %predphi290.v.i = select <4 x i1> undef, <4 x i32> undef, <4 x i32> %2 + store <4 x i32> %predphi290.v.i, <4 x i32>* undef, align 4 + ret void +}