From 78fe9ababead2168f7196c6a47402cf499a0aaf7 Mon Sep 17 00:00:00 2001 From: Evan Cheng Date: Tue, 29 Mar 2011 01:56:09 +0000 Subject: [PATCH] Optimizing (zext A + zext B) * C, to (VMULL A, C) + (VMULL B, C) during isel lowering to fold the zero-extend's and take advantage of no-stall back to back vmul + vmla: vmull q0, d4, d6 vmlal q0, d5, d6 is faster than vaddl q0, d4, d5 vmovl q1, d6 vmul q0, q0, q1 This allows us to vmull + vmlal for: f = vmull_u8( vget_high_u8(s), c); f = vmlal_u8(f, vget_low_u8(s), c); rdar://9197392 git-svn-id: https://llvm.org/svn/llvm-project/llvm/trunk@128444 91177308-0d34-0410-b5e6-96231b3b80d8 --- lib/Target/ARM/ARMISelLowering.cpp | 94 +++++++++++++++++++++++++----- test/CodeGen/ARM/vmul.ll | 29 +++++++++ 2 files changed, 109 insertions(+), 14 deletions(-) diff --git a/lib/Target/ARM/ARMISelLowering.cpp b/lib/Target/ARM/ARMISelLowering.cpp index 623393df4aa..3267ea88f39 100644 --- a/lib/Target/ARM/ARMISelLowering.cpp +++ b/lib/Target/ARM/ARMISelLowering.cpp @@ -4370,6 +4370,28 @@ static SDValue SkipExtension(SDNode *N, SelectionDAG &DAG) { MVT::getVectorVT(TruncVT, NumElts), Ops.data(), NumElts); } +static bool isAddSubSExt(SDNode *N, SelectionDAG &DAG) { + unsigned Opcode = N->getOpcode(); + if (Opcode == ISD::ADD || Opcode == ISD::SUB) { + SDNode *N0 = N->getOperand(0).getNode(); + SDNode *N1 = N->getOperand(1).getNode(); + return N0->hasOneUse() && N1->hasOneUse() && + isSignExtended(N0, DAG) && isSignExtended(N1, DAG); + } + return false; +} + +static bool isAddSubZExt(SDNode *N, SelectionDAG &DAG) { + unsigned Opcode = N->getOpcode(); + if (Opcode == ISD::ADD || Opcode == ISD::SUB) { + SDNode *N0 = N->getOperand(0).getNode(); + SDNode *N1 = N->getOperand(1).getNode(); + return N0->hasOneUse() && N1->hasOneUse() && + isZeroExtended(N0, DAG) && isZeroExtended(N1, DAG); + } + return false; +} + static SDValue LowerMUL(SDValue Op, SelectionDAG &DAG) { // Multiplications are only custom-lowered for 128-bit vectors so that // VMULL can be detected. Otherwise v2i64 multiplications are not legal. @@ -4378,26 +4400,70 @@ static SDValue LowerMUL(SDValue Op, SelectionDAG &DAG) { SDNode *N0 = Op.getOperand(0).getNode(); SDNode *N1 = Op.getOperand(1).getNode(); unsigned NewOpc = 0; - if (isSignExtended(N0, DAG) && isSignExtended(N1, DAG)) + bool isMLA = false; + bool isN0SExt = isSignExtended(N0, DAG); + bool isN1SExt = isSignExtended(N1, DAG); + if (isN0SExt && isN1SExt) NewOpc = ARMISD::VMULLs; - else if (isZeroExtended(N0, DAG) && isZeroExtended(N1, DAG)) - NewOpc = ARMISD::VMULLu; - else if (VT == MVT::v2i64) - // Fall through to expand this. It is not legal. - return SDValue(); - else - // Other vector multiplications are legal. - return Op; + else { + bool isN0ZExt = isZeroExtended(N0, DAG); + bool isN1ZExt = isZeroExtended(N1, DAG); + if (isN0ZExt && isN1ZExt) + NewOpc = ARMISD::VMULLu; + else if (isN1SExt || isN1ZExt) { + // Look for (s/zext A + s/zext B) * (s/zext C). We want to turn these + // into (s/zext A * s/zext C) + (s/zext B * s/zext C) + if (isN1SExt && isAddSubSExt(N0, DAG)) { + NewOpc = ARMISD::VMULLs; + isMLA = true; + } else if (isN1ZExt && isAddSubZExt(N0, DAG)) { + NewOpc = ARMISD::VMULLu; + isMLA = true; + } else if (isN0ZExt && isAddSubZExt(N1, DAG)) { + std::swap(N0, N1); + NewOpc = ARMISD::VMULLu; + isMLA = true; + } + } + + if (!NewOpc) { + if (VT == MVT::v2i64) + // Fall through to expand this. It is not legal. + return SDValue(); + else + // Other vector multiplications are legal. + return Op; + } + } // Legalize to a VMULL instruction. DebugLoc DL = Op.getDebugLoc(); - SDValue Op0 = SkipExtension(N0, DAG); + SDValue Op0; SDValue Op1 = SkipExtension(N1, DAG); + if (!isMLA) { + Op0 = SkipExtension(N0, DAG); + assert(Op0.getValueType().is64BitVector() && + Op1.getValueType().is64BitVector() && + "unexpected types for extended operands to VMULL"); + return DAG.getNode(NewOpc, DL, VT, Op0, Op1); + } - assert(Op0.getValueType().is64BitVector() && - Op1.getValueType().is64BitVector() && - "unexpected types for extended operands to VMULL"); - return DAG.getNode(NewOpc, DL, VT, Op0, Op1); + // Optimizing (zext A + zext B) * C, to (VMULL A, C) + (VMULL B, C) during + // isel lowering to take advantage of no-stall back to back vmul + vmla. + // vmull q0, d4, d6 + // vmlal q0, d5, d6 + // is faster than + // vaddl q0, d4, d5 + // vmovl q1, d6 + // vmul q0, q0, q1 + SDValue N00 = SkipExtension(N0->getOperand(0).getNode(), DAG); + SDValue N01 = SkipExtension(N0->getOperand(1).getNode(), DAG); + EVT Op1VT = Op1.getValueType(); + return DAG.getNode(N0->getOpcode(), DL, VT, + DAG.getNode(NewOpc, DL, VT, + DAG.getNode(ISD::BITCAST, DL, Op1VT, N00), Op1), + DAG.getNode(NewOpc, DL, VT, + DAG.getNode(ISD::BITCAST, DL, Op1VT, N01), Op1)); } static SDValue diff --git a/test/CodeGen/ARM/vmul.ll b/test/CodeGen/ARM/vmul.ll index ee033caa00d..585394ec30a 100644 --- a/test/CodeGen/ARM/vmul.ll +++ b/test/CodeGen/ARM/vmul.ll @@ -339,3 +339,32 @@ define <2 x i64> @vmull_extvec_u32(<2 x i32> %arg) nounwind { %tmp4 = mul <2 x i64> %tmp3, ret <2 x i64> %tmp4 } + +; rdar://9197392 +define void @distribue(i16* %dst, i8* %src, i32 %mul) nounwind { +entry: +; CHECK: distribue: +; CHECK: vmull.u8 [[REG1:(q[0-9]+)]], d{{.*}}, [[REG2:(d[0-9]+)]] +; CHECK: vmlal.u8 [[REG1]], d{{.*}}, [[REG2]] + %0 = trunc i32 %mul to i8 + %1 = insertelement <8 x i8> undef, i8 %0, i32 0 + %2 = shufflevector <8 x i8> %1, <8 x i8> undef, <8 x i32> zeroinitializer + %3 = tail call <16 x i8> @llvm.arm.neon.vld1.v16i8(i8* %src, i32 1) + %4 = bitcast <16 x i8> %3 to <2 x double> + %5 = extractelement <2 x double> %4, i32 1 + %6 = bitcast double %5 to <8 x i8> + %7 = zext <8 x i8> %6 to <8 x i16> + %8 = zext <8 x i8> %2 to <8 x i16> + %9 = extractelement <2 x double> %4, i32 0 + %10 = bitcast double %9 to <8 x i8> + %11 = zext <8 x i8> %10 to <8 x i16> + %12 = add <8 x i16> %7, %11 + %13 = mul <8 x i16> %12, %8 + %14 = bitcast i16* %dst to i8* + tail call void @llvm.arm.neon.vst1.v8i16(i8* %14, <8 x i16> %13, i32 2) + ret void +} + +declare <16 x i8> @llvm.arm.neon.vld1.v16i8(i8*, i32) nounwind readonly + +declare void @llvm.arm.neon.vst1.v8i16(i8*, <8 x i16>, i32) nounwind