From 448a1a07348dad12592cce2fe4d9413024b81308 Mon Sep 17 00:00:00 2001 From: Matt Arsenault Date: Mon, 14 Apr 2014 21:50:37 +0000 Subject: [PATCH] Revert "Revert r206045, "Fix shift by constants for vector."" Fix cases where the Value itself is used, and not the constant value. git-svn-id: https://llvm.org/svn/llvm-project/llvm/trunk@206214 91177308-0d34-0410-b5e6-96231b3b80d8 --- lib/Transforms/InstCombine/InstCombine.h | 2 +- .../InstCombine/InstCombineShifts.cpp | 48 ++++++++---- test/Transforms/InstCombine/pr19420.ll | 67 ++++++++++++++++ test/Transforms/InstCombine/shift.ll | 78 +++++++++++++++++++ 4 files changed, 179 insertions(+), 16 deletions(-) create mode 100644 test/Transforms/InstCombine/pr19420.ll diff --git a/lib/Transforms/InstCombine/InstCombine.h b/lib/Transforms/InstCombine/InstCombine.h index 822e146ac46..4ee2f59c174 100644 --- a/lib/Transforms/InstCombine/InstCombine.h +++ b/lib/Transforms/InstCombine/InstCombine.h @@ -171,7 +171,7 @@ public: ICmpInst::Predicate Pred); Instruction *FoldGEPICmp(GEPOperator *GEPLHS, Value *RHS, ICmpInst::Predicate Cond, Instruction &I); - Instruction *FoldShiftByConstant(Value *Op0, ConstantInt *Op1, + Instruction *FoldShiftByConstant(Value *Op0, Constant *Op1, BinaryOperator &I); Instruction *commonCastTransforms(CastInst &CI); Instruction *commonPointerCastTransforms(CastInst &CI); diff --git a/lib/Transforms/InstCombine/InstCombineShifts.cpp b/lib/Transforms/InstCombine/InstCombineShifts.cpp index e873c1d31c4..f2e2cb3e66f 100644 --- a/lib/Transforms/InstCombine/InstCombineShifts.cpp +++ b/lib/Transforms/InstCombine/InstCombineShifts.cpp @@ -33,7 +33,7 @@ Instruction *InstCombiner::commonShiftTransforms(BinaryOperator &I) { if (Instruction *R = FoldOpIntoSelect(I, SI)) return R; - if (ConstantInt *CUI = dyn_cast(Op1)) + if (Constant *CUI = dyn_cast(Op1)) if (Instruction *Res = FoldShiftByConstant(Op0, CUI, I)) return Res; @@ -309,20 +309,30 @@ static Value *GetShiftedValue(Value *V, unsigned NumBits, bool isLeftShift, -Instruction *InstCombiner::FoldShiftByConstant(Value *Op0, ConstantInt *Op1, +Instruction *InstCombiner::FoldShiftByConstant(Value *Op0, Constant *Op1, BinaryOperator &I) { bool isLeftShift = I.getOpcode() == Instruction::Shl; + ConstantInt *COp1 = nullptr; + if (ConstantDataVector *CV = dyn_cast(Op1)) + COp1 = dyn_cast_or_null(CV->getSplatValue()); + else if (ConstantVector *CV = dyn_cast(Op1)) + COp1 = dyn_cast_or_null(CV->getSplatValue()); + else + COp1 = dyn_cast(Op1); + + if (!COp1) + return nullptr; // See if we can propagate this shift into the input, this covers the trivial // cast of lshr(shl(x,c1),c2) as well as other more complex cases. if (I.getOpcode() != Instruction::AShr && - CanEvaluateShifted(Op0, Op1->getZExtValue(), isLeftShift, *this)) { + CanEvaluateShifted(Op0, COp1->getZExtValue(), isLeftShift, *this)) { DEBUG(dbgs() << "ICE: GetShiftedValue propagating shift through expression" " to eliminate shift:\n IN: " << *Op0 << "\n SH: " << I <<"\n"); return ReplaceInstUsesWith(I, - GetShiftedValue(Op0, Op1->getZExtValue(), isLeftShift, *this)); + GetShiftedValue(Op0, COp1->getZExtValue(), isLeftShift, *this)); } @@ -333,7 +343,7 @@ Instruction *InstCombiner::FoldShiftByConstant(Value *Op0, ConstantInt *Op1, // shl i32 X, 32 = 0 and srl i8 Y, 9 = 0, ... just don't eliminate // a signed shift. // - if (Op1->uge(TypeBits)) { + if (COp1->uge(TypeBits)) { if (I.getOpcode() != Instruction::AShr) return ReplaceInstUsesWith(I, Constant::getNullValue(Op0->getType())); // ashr i32 X, 32 --> ashr i32 X, 31 @@ -367,7 +377,7 @@ Instruction *InstCombiner::FoldShiftByConstant(Value *Op0, ConstantInt *Op1, if (TrOp && I.isLogicalShift() && TrOp->isShift() && isa(TrOp->getOperand(1))) { // Okay, we'll do this xform. Make the shift of shift. - Constant *ShAmt = ConstantExpr::getZExt(Op1, TrOp->getType()); + Constant *ShAmt = ConstantExpr::getZExt(COp1, TrOp->getType()); // (shift2 (shift1 & 0x00FF), c2) Value *NSh = Builder->CreateBinOp(I.getOpcode(), TrOp, ShAmt,I.getName()); @@ -384,10 +394,10 @@ Instruction *InstCombiner::FoldShiftByConstant(Value *Op0, ConstantInt *Op1, // shift. We know that it is a logical shift by a constant, so adjust the // mask as appropriate. if (I.getOpcode() == Instruction::Shl) - MaskV <<= Op1->getZExtValue(); + MaskV <<= COp1->getZExtValue(); else { assert(I.getOpcode() == Instruction::LShr && "Unknown logical shift"); - MaskV = MaskV.lshr(Op1->getZExtValue()); + MaskV = MaskV.lshr(COp1->getZExtValue()); } // shift1 & 0x00FF @@ -421,9 +431,13 @@ Instruction *InstCombiner::FoldShiftByConstant(Value *Op0, ConstantInt *Op1, // (X + (Y << C)) Value *X = Builder->CreateBinOp(Op0BO->getOpcode(), YS, V1, Op0BO->getOperand(1)->getName()); - uint32_t Op1Val = Op1->getLimitedValue(TypeBits); - return BinaryOperator::CreateAnd(X, ConstantInt::get(I.getContext(), - APInt::getHighBitsSet(TypeBits, TypeBits-Op1Val))); + uint32_t Op1Val = COp1->getLimitedValue(TypeBits); + + APInt Bits = APInt::getHighBitsSet(TypeBits, TypeBits - Op1Val); + Constant *Mask = ConstantInt::get(I.getContext(), Bits); + if (VectorType *VT = dyn_cast(X->getType())) + Mask = ConstantVector::getSplat(VT->getNumElements(), Mask); + return BinaryOperator::CreateAnd(X, Mask); } // Turn (Y + ((X >> C) & CC)) << C -> ((X & (CC << C)) + (Y << C)) @@ -453,9 +467,13 @@ Instruction *InstCombiner::FoldShiftByConstant(Value *Op0, ConstantInt *Op1, // (X + (Y << C)) Value *X = Builder->CreateBinOp(Op0BO->getOpcode(), V1, YS, Op0BO->getOperand(0)->getName()); - uint32_t Op1Val = Op1->getLimitedValue(TypeBits); - return BinaryOperator::CreateAnd(X, ConstantInt::get(I.getContext(), - APInt::getHighBitsSet(TypeBits, TypeBits-Op1Val))); + uint32_t Op1Val = COp1->getLimitedValue(TypeBits); + + APInt Bits = APInt::getHighBitsSet(TypeBits, TypeBits - Op1Val); + Constant *Mask = ConstantInt::get(I.getContext(), Bits); + if (VectorType *VT = dyn_cast(X->getType())) + Mask = ConstantVector::getSplat(VT->getNumElements(), Mask); + return BinaryOperator::CreateAnd(X, Mask); } // Turn (((X >> C)&CC) + Y) << C -> (X + (Y << C)) & (CC << C) @@ -541,7 +559,7 @@ Instruction *InstCombiner::FoldShiftByConstant(Value *Op0, ConstantInt *Op1, ConstantInt *ShiftAmt1C = cast(ShiftOp->getOperand(1)); uint32_t ShiftAmt1 = ShiftAmt1C->getLimitedValue(TypeBits); - uint32_t ShiftAmt2 = Op1->getLimitedValue(TypeBits); + uint32_t ShiftAmt2 = COp1->getLimitedValue(TypeBits); assert(ShiftAmt2 != 0 && "Should have been simplified earlier"); if (ShiftAmt1 == 0) return 0; // Will be simplified in the future. Value *X = ShiftOp->getOperand(0); diff --git a/test/Transforms/InstCombine/pr19420.ll b/test/Transforms/InstCombine/pr19420.ll new file mode 100644 index 00000000000..23fa0a40974 --- /dev/null +++ b/test/Transforms/InstCombine/pr19420.ll @@ -0,0 +1,67 @@ +; RUN: opt -S -instcombine < %s | FileCheck %s + +; CHECK-LABEL: @test_FoldShiftByConstant_CreateSHL +; CHECK: mul <4 x i32> %in, +; CHECK-NEXT: ret +define <4 x i32> @test_FoldShiftByConstant_CreateSHL(<4 x i32> %in) { + %mul.i = mul <4 x i32> %in, + %vshl_n = shl <4 x i32> %mul.i, + ret <4 x i32> %vshl_n +} + +; CHECK-LABEL: @test_FoldShiftByConstant_CreateSHL2 +; CHECK: mul <8 x i16> %in, +; CHECK-NEXT: ret +define <8 x i16> @test_FoldShiftByConstant_CreateSHL2(<8 x i16> %in) { + %mul.i = mul <8 x i16> %in, + %vshl_n = shl <8 x i16> %mul.i, + ret <8 x i16> %vshl_n +} + +; CHECK-LABEL: @test_FoldShiftByConstant_CreateAnd +; CHECK: mul <16 x i8> %in0, +; CHECK-NEXT: and <16 x i8> %vsra_n2, +; CHECK-NEXT: ret +define <16 x i8> @test_FoldShiftByConstant_CreateAnd(<16 x i8> %in0) { + %vsra_n = ashr <16 x i8> %in0, + %tmp = add <16 x i8> %in0, %vsra_n + %vshl_n = shl <16 x i8> %tmp, + ret <16 x i8> %vshl_n +} + + +define i32 @bar(i32 %x, i32 %y) { + %a = lshr i32 %x, 4 + %b = add i32 %a, %y + %c = shl i32 %b, 4 + ret i32 %c +} + +define <2 x i32> @bar_v2i32(<2 x i32> %x, <2 x i32> %y) { + %a = lshr <2 x i32> %x, + %b = add <2 x i32> %a, %y + %c = shl <2 x i32> %b, + ret <2 x i32> %c +} + + + + +define i32 @foo(i32 %x, i32 %y) { + %a = lshr i32 %x, 4 + %b = and i32 %a, 8 + %c = add i32 %b, %y + %d = shl i32 %c, 4 + ret i32 %d +} + +define <2 x i32> @foo_v2i32(<2 x i32> %x, <2 x i32> %y) { + %a = lshr <2 x i32> %x, + %b = and <2 x i32> %a, + %c = add <2 x i32> %b, %y + %d = shl <2 x i32> %c, + ret <2 x i32> %d +} + + + diff --git a/test/Transforms/InstCombine/shift.ll b/test/Transforms/InstCombine/shift.ll index 8f0bbd1487c..5586bb65278 100644 --- a/test/Transforms/InstCombine/shift.ll +++ b/test/Transforms/InstCombine/shift.ll @@ -40,6 +40,27 @@ define i32 @test5(i32 %A) { ret i32 %B } +define <4 x i32> @test5_splat_vector(<4 x i32> %A) { +; CHECK-LABEL: @test5_splat_vector( +; CHECK: ret <4 x i32> undef + %B = lshr <4 x i32> %A, ;; shift all bits out + ret <4 x i32> %B +} + +define <4 x i32> @test5_zero_vector(<4 x i32> %A) { +; CHECK-LABEL: @test5_zero_vector( +; CHECK-NEXT: ret <4 x i32> %A + %B = lshr <4 x i32> %A, zeroinitializer + ret <4 x i32> %B +} + +define <4 x i32> @test5_non_splat_vector(<4 x i32> %A) { +; CHECK-LABEL: @test5_non_splat_vector( +; CHECK-NOT: ret <4 x i32> undef + %B = shl <4 x i32> %A, + ret <4 x i32> %B +} + define i32 @test5a(i32 %A) { ; CHECK-LABEL: @test5a( ; CHECK: ret i32 undef @@ -47,6 +68,20 @@ define i32 @test5a(i32 %A) { ret i32 %B } +define <4 x i32> @test5a_splat_vector(<4 x i32> %A) { +; CHECK-LABEL: @test5a_splat_vector( +; CHECK: ret <4 x i32> undef + %B = shl <4 x i32> %A, ;; shift all bits out + ret <4 x i32> %B +} + +define <4 x i32> @test5a_non_splat_vector(<4 x i32> %A) { +; CHECK-LABEL: @test5a_non_splat_vector( +; CHECK-NOT: ret <4 x i32> undef + %B = shl <4 x i32> %A, + ret <4 x i32> %B +} + define i32 @test5b() { ; CHECK-LABEL: @test5b( ; CHECK: ret i32 -1 @@ -344,6 +379,20 @@ define i32 @test25(i32 %tmp.2, i32 %AA) { ret i32 %tmp.6 } +define <2 x i32> @test25_vector(<2 x i32> %tmp.2, <2 x i32> %AA) { +; CHECK-LABEL: @test25_vector( +; CHECK: %tmp.3 = lshr <2 x i32> %tmp.2, +; CHECK-NEXT: shl <2 x i32> %tmp.3, +; CHECK-NEXT: add <2 x i32> %tmp.51, %AA +; CHECK-NEXT: and <2 x i32> %x2, +; CHECK-NEXT: ret <2 x i32> + %x = lshr <2 x i32> %AA, + %tmp.3 = lshr <2 x i32> %tmp.2, + %tmp.5 = add <2 x i32> %tmp.3, %x + %tmp.6 = shl <2 x i32> %tmp.5, + ret <2 x i32> %tmp.6 +} + ;; handle casts between shifts. define i32 @test26(i32 %A) { ; CHECK-LABEL: @test26( @@ -780,3 +829,32 @@ bb11: ; preds = %bb8 bb12: ; preds = %bb11, %bb8, %bb ret void } + +define i32 @test64(i32 %a) { +; CHECK-LABEL: @test64( +; CHECK-NEXT: ret i32 undef + %b = ashr i32 %a, 32 ; shift all bits out + ret i32 %b +} + +define <4 x i32> @test64_splat_vector(<4 x i32> %a) { +; CHECK-LABEL: @test64_splat_vector +; CHECK-NEXT: ret <4 x i32> undef + %b = ashr <4 x i32> %a, ; shift all bits out + ret <4 x i32> %b +} + +define <4 x i32> @test64_non_splat_vector(<4 x i32> %a) { +; CHECK-LABEL: @test64_non_splat_vector +; CHECK-NOT: ret <4 x i32> undef + %b = ashr <4 x i32> %a, ; shift all bits out + ret <4 x i32> %b +} + +define <2 x i65> @test_65(<2 x i64> %t) { +; CHECK-LABEL: @test_65 + %a = zext <2 x i64> %t to <2 x i65> + %sext = shl <2 x i65> %a, + %b = ashr <2 x i65> %sext, + ret <2 x i65> %b +}