From fb33ce9956e10d74031b54fdaf4e341c64ca6de1 Mon Sep 17 00:00:00 2001 From: Matt Arsenault Date: Fri, 11 Apr 2014 17:57:53 +0000 Subject: [PATCH] Fix shift by constants for vector. ashr <N x iM>, M -> undef git-svn-id: https://llvm.org/svn/llvm-project/llvm/trunk@206045 91177308-0d34-0410-b5e6-96231b3b80d8 --- lib/Transforms/InstCombine/InstCombine.h | 2 +- .../InstCombine/InstCombineShifts.cpp | 34 +++++--- test/Transforms/InstCombine/shift.ll | 80 +++++++++++++++++-- 3 files changed, 95 insertions(+), 21 deletions(-) diff --git a/lib/Transforms/InstCombine/InstCombine.h b/lib/Transforms/InstCombine/InstCombine.h index 822e146ac46..4ee2f59c174 100644 --- a/lib/Transforms/InstCombine/InstCombine.h +++ b/lib/Transforms/InstCombine/InstCombine.h @@ -171,7 +171,7 @@ public: ICmpInst::Predicate Pred); Instruction *FoldGEPICmp(GEPOperator *GEPLHS, Value *RHS, ICmpInst::Predicate Cond, Instruction &I); - Instruction *FoldShiftByConstant(Value *Op0, ConstantInt *Op1, + Instruction *FoldShiftByConstant(Value *Op0, Constant *Op1, BinaryOperator &I); Instruction *commonCastTransforms(CastInst &CI); Instruction *commonPointerCastTransforms(CastInst &CI); diff --git a/lib/Transforms/InstCombine/InstCombineShifts.cpp b/lib/Transforms/InstCombine/InstCombineShifts.cpp index 8273dfd4887..536b148d29c 100644 --- a/lib/Transforms/InstCombine/InstCombineShifts.cpp +++ b/lib/Transforms/InstCombine/InstCombineShifts.cpp @@ -33,7 +33,7 @@ Instruction *InstCombiner::commonShiftTransforms(BinaryOperator &I) { if (Instruction *R = FoldOpIntoSelect(I, SI)) return R; - if (ConstantInt *CUI = dyn_cast<ConstantInt>(Op1)) + if (Constant *CUI = dyn_cast<Constant>(Op1)) if (Instruction *Res = FoldShiftByConstant(Op0, CUI, I)) return Res; @@ -309,20 +309,30 @@ static Value *GetShiftedValue(Value *V, unsigned NumBits, bool isLeftShift, -Instruction *InstCombiner::FoldShiftByConstant(Value *Op0, ConstantInt *Op1, +Instruction *InstCombiner::FoldShiftByConstant(Value *Op0, Constant *Op1, BinaryOperator &I) { bool isLeftShift = 
I.getOpcode() == Instruction::Shl; + ConstantInt *COp1 = nullptr; + if (ConstantDataVector *CV = dyn_cast<ConstantDataVector>(Op1)) + COp1 = dyn_cast_or_null<ConstantInt>(CV->getSplatValue()); + else if (ConstantVector *CV = dyn_cast<ConstantVector>(Op1)) + COp1 = dyn_cast_or_null<ConstantInt>(CV->getSplatValue()); + else + COp1 = dyn_cast<ConstantInt>(Op1); + + if (!COp1) + return nullptr; // See if we can propagate this shift into the input, this covers the trivial // cast of lshr(shl(x,c1),c2) as well as other more complex cases. if (I.getOpcode() != Instruction::AShr && - CanEvaluateShifted(Op0, Op1->getZExtValue(), isLeftShift, *this)) { + CanEvaluateShifted(Op0, COp1->getZExtValue(), isLeftShift, *this)) { DEBUG(dbgs() << "ICE: GetShiftedValue propagating shift through expression" " to eliminate shift:\n IN: " << *Op0 << "\n SH: " << I <<"\n"); return ReplaceInstUsesWith(I, - GetShiftedValue(Op0, Op1->getZExtValue(), isLeftShift, *this)); + GetShiftedValue(Op0, COp1->getZExtValue(), isLeftShift, *this)); } @@ -333,7 +343,7 @@ Instruction *InstCombiner::FoldShiftByConstant(Value *Op0, ConstantInt *Op1, // shl i32 X, 32 = 0 and srl i8 Y, 9 = 0, ... just don't eliminate // a signed shift. // - if (Op1->uge(TypeBits)) { + if (COp1->uge(TypeBits)) { if (I.getOpcode() != Instruction::AShr) return ReplaceInstUsesWith(I, Constant::getNullValue(Op0->getType())); // ashr i32 X, 32 --> ashr i32 X, 31 @@ -346,7 +356,7 @@ Instruction *InstCombiner::FoldShiftByConstant(Value *Op0, ConstantInt *Op1, if (BO->getOpcode() == Instruction::Mul && isLeftShift) if (Constant *BOOp = dyn_cast<Constant>(BO->getOperand(1))) return BinaryOperator::CreateMul(BO->getOperand(0), - ConstantExpr::getShl(BOOp, Op1)); + ConstantExpr::getShl(BOOp, COp1)); // Try to fold constant and into select arguments. if (SelectInst *SI = dyn_cast<SelectInst>(Op0)) @@ -367,7 +377,7 @@ Instruction *InstCombiner::FoldShiftByConstant(Value *Op0, ConstantInt *Op1, if (TrOp && I.isLogicalShift() && TrOp->isShift() && isa<ConstantInt>(TrOp->getOperand(1))) { // Okay, we'll do this xform. Make the shift of shift. 
- Constant *ShAmt = ConstantExpr::getZExt(Op1, TrOp->getType()); + Constant *ShAmt = ConstantExpr::getZExt(COp1, TrOp->getType()); // (shift2 (shift1 & 0x00FF), c2) Value *NSh = Builder->CreateBinOp(I.getOpcode(), TrOp, ShAmt,I.getName()); @@ -384,10 +394,10 @@ Instruction *InstCombiner::FoldShiftByConstant(Value *Op0, ConstantInt *Op1, // shift. We know that it is a logical shift by a constant, so adjust the // mask as appropriate. if (I.getOpcode() == Instruction::Shl) - MaskV <<= Op1->getZExtValue(); + MaskV <<= COp1->getZExtValue(); else { assert(I.getOpcode() == Instruction::LShr && "Unknown logical shift"); - MaskV = MaskV.lshr(Op1->getZExtValue()); + MaskV = MaskV.lshr(COp1->getZExtValue()); } // shift1 & 0x00FF @@ -421,7 +431,7 @@ Instruction *InstCombiner::FoldShiftByConstant(Value *Op0, ConstantInt *Op1, // (X + (Y << C)) Value *X = Builder->CreateBinOp(Op0BO->getOpcode(), YS, V1, Op0BO->getOperand(1)->getName()); - uint32_t Op1Val = Op1->getLimitedValue(TypeBits); + uint32_t Op1Val = COp1->getLimitedValue(TypeBits); return BinaryOperator::CreateAnd(X, ConstantInt::get(I.getContext(), APInt::getHighBitsSet(TypeBits, TypeBits-Op1Val))); } @@ -453,7 +463,7 @@ Instruction *InstCombiner::FoldShiftByConstant(Value *Op0, ConstantInt *Op1, // (X + (Y << C)) Value *X = Builder->CreateBinOp(Op0BO->getOpcode(), V1, YS, Op0BO->getOperand(0)->getName()); - uint32_t Op1Val = Op1->getLimitedValue(TypeBits); + uint32_t Op1Val = COp1->getLimitedValue(TypeBits); return BinaryOperator::CreateAnd(X, ConstantInt::get(I.getContext(), APInt::getHighBitsSet(TypeBits, TypeBits-Op1Val))); } @@ -541,7 +551,7 @@ Instruction *InstCombiner::FoldShiftByConstant(Value *Op0, ConstantInt *Op1, ConstantInt *ShiftAmt1C = cast<ConstantInt>(ShiftOp->getOperand(1)); uint32_t ShiftAmt1 = ShiftAmt1C->getLimitedValue(TypeBits); - uint32_t ShiftAmt2 = Op1->getLimitedValue(TypeBits); + uint32_t ShiftAmt2 = COp1->getLimitedValue(TypeBits); assert(ShiftAmt2 != 0 && "Should have been simplified earlier"); if 
(ShiftAmt1 == 0) return 0; // Will be simplified in the future. Value *X = ShiftOp->getOperand(0); diff --git a/test/Transforms/InstCombine/shift.ll b/test/Transforms/InstCombine/shift.ll index b1082f06ef7..bbfe5c6c1df 100644 --- a/test/Transforms/InstCombine/shift.ll +++ b/test/Transforms/InstCombine/shift.ll @@ -36,17 +36,52 @@ define i32 @test4(i8 %A) { define i32 @test5(i32 %A) { ; CHECK-LABEL: @test5( ; CHECK: ret i32 undef - %B = lshr i32 %A, 32 ;; shift all bits out + %B = lshr i32 %A, 32 ;; shift all bits out ret i32 %B } +define <4 x i32> @test5_splat_vector(<4 x i32> %A) { +; CHECK-LABEL: @test5_splat_vector( +; CHECK: ret <4 x i32> undef + %B = lshr <4 x i32> %A, ;; shift all bits out + ret <4 x i32> %B +} + +define <4 x i32> @test5_zero_vector(<4 x i32> %A) { +; CHECK-LABEL: @test5_zero_vector( +; CHECK-NEXT: ret <4 x i32> %A + %B = lshr <4 x i32> %A, zeroinitializer + ret <4 x i32> %B +} + +define <4 x i32> @test5_non_splat_vector(<4 x i32> %A) { +; CHECK-LABEL: @test5_non_splat_vector( +; CHECK-NOT: ret <4 x i32> undef + %B = shl <4 x i32> %A, + ret <4 x i32> %B +} + define i32 @test5a(i32 %A) { ; CHECK-LABEL: @test5a( ; CHECK: ret i32 undef - %B = shl i32 %A, 32 ;; shift all bits out + %B = shl i32 %A, 32 ;; shift all bits out ret i32 %B } +define <4 x i32> @test5a_splat_vector(<4 x i32> %A) { +; CHECK-LABEL: @test5a_splat_vector( +; CHECK: ret <4 x i32> undef + %B = shl <4 x i32> %A, ;; shift all bits out + ret <4 x i32> %B +} + +define <4 x i32> @test5a_non_splat_vector(<4 x i32> %A) { +; CHECK-LABEL: @test5a_non_splat_vector( +; CHECK-NOT: ret <4 x i32> undef + %B = shl <4 x i32> %A, + ret <4 x i32> %B +} + define i32 @test5b() { ; CHECK-LABEL: @test5b( ; CHECK: ret i32 -1 @@ -82,7 +117,7 @@ define i32 @test6a(i32 %A) { define i32 @test7(i8 %A) { ; CHECK-LABEL: @test7( ; CHECK-NEXT: ret i32 -1 - %shift.upgrd.3 = zext i8 %A to i32 + %shift.upgrd.3 = zext i8 %A to i32 %B = ashr i32 -1, %shift.upgrd.3 ;; Always equal to -1 ret i32 %B } @@ -232,7 
+267,7 @@ define i1 @test16(i32 %X) { ; CHECK-NEXT: and i32 %X, 16 ; CHECK-NEXT: icmp ne i32 ; CHECK-NEXT: ret i1 - %tmp.3 = ashr i32 %X, 4 + %tmp.3 = ashr i32 %X, 4 %tmp.6 = and i32 %tmp.3, 1 %tmp.7 = icmp ne i32 %tmp.6, 0 ret i1 %tmp.7 @@ -365,12 +400,12 @@ define i1 @test27(i32 %x) nounwind { %z = trunc i32 %y to i1 ret i1 %z } - + define i8 @test28(i8 %x) { entry: ; CHECK-LABEL: @test28( ; CHECK: icmp slt i8 %x, 0 -; CHECK-NEXT: br i1 +; CHECK-NEXT: br i1 %tmp1 = lshr i8 %x, 7 %cond1 = icmp ne i8 %tmp1, 0 br i1 %cond1, label %bb1, label %bb2 @@ -476,7 +511,7 @@ entry: %ins = or i128 %tmp23, %tmp27 %tmp45 = lshr i128 %ins, 64 ret i128 %tmp45 - + ; CHECK-LABEL: @test36( ; CHECK: %tmp231 = or i128 %B, %A ; CHECK: %ins = and i128 %tmp231, 18446744073709551615 @@ -492,7 +527,7 @@ entry: %tmp45 = lshr i128 %ins, 64 %tmp46 = trunc i128 %tmp45 to i64 ret i64 %tmp46 - + ; CHECK-LABEL: @test37( ; CHECK: %tmp23 = shl nuw nsw i128 %tmp22, 32 ; CHECK: %ins = or i128 %tmp23, %A @@ -780,3 +815,32 @@ bb11: ; preds = %bb8 bb12: ; preds = %bb11, %bb8, %bb ret void } + +define i32 @test64(i32 %a) { +; CHECK-LABEL: @test64( +; CHECK-NEXT: ret i32 undef + %b = ashr i32 %a, 32 ; shift all bits out + ret i32 %b +} + +define <4 x i32> @test64_splat_vector(<4 x i32> %a) { +; CHECK-LABEL: @test64_splat_vector +; CHECK-NEXT: ret <4 x i32> undef + %b = ashr <4 x i32> %a, ; shift all bits out + ret <4 x i32> %b +} + +define <4 x i32> @test64_non_splat_vector(<4 x i32> %a) { +; CHECK-LABEL: @test64_non_splat_vector +; CHECK-NOT: ret <4 x i32> undef + %b = ashr <4 x i32> %a, ; shift all bits out + ret <4 x i32> %b +} + +define <2 x i65> @test_65(<2 x i64> %t) { +; CHECK-LABEL: @test_65 + %a = zext <2 x i64> %t to <2 x i65> + %sext = shl <2 x i65> %a, + %b = ashr <2 x i65> %sext, + ret <2 x i65> %b +}