Revert "Revert r206045, "Fix shift by constants for vector.""

Fix cases where the Value itself is used, and not the constant value.

git-svn-id: https://llvm.org/svn/llvm-project/llvm/trunk@206214 91177308-0d34-0410-b5e6-96231b3b80d8
This commit is contained in:
Matt Arsenault 2014-04-14 21:50:37 +00:00
parent 05620e5439
commit 448a1a0734
4 changed files with 179 additions and 16 deletions

View File

@ -171,7 +171,7 @@ public:
ICmpInst::Predicate Pred); ICmpInst::Predicate Pred);
Instruction *FoldGEPICmp(GEPOperator *GEPLHS, Value *RHS, Instruction *FoldGEPICmp(GEPOperator *GEPLHS, Value *RHS,
ICmpInst::Predicate Cond, Instruction &I); ICmpInst::Predicate Cond, Instruction &I);
Instruction *FoldShiftByConstant(Value *Op0, ConstantInt *Op1, Instruction *FoldShiftByConstant(Value *Op0, Constant *Op1,
BinaryOperator &I); BinaryOperator &I);
Instruction *commonCastTransforms(CastInst &CI); Instruction *commonCastTransforms(CastInst &CI);
Instruction *commonPointerCastTransforms(CastInst &CI); Instruction *commonPointerCastTransforms(CastInst &CI);

View File

@ -33,7 +33,7 @@ Instruction *InstCombiner::commonShiftTransforms(BinaryOperator &I) {
if (Instruction *R = FoldOpIntoSelect(I, SI)) if (Instruction *R = FoldOpIntoSelect(I, SI))
return R; return R;
if (ConstantInt *CUI = dyn_cast<ConstantInt>(Op1)) if (Constant *CUI = dyn_cast<Constant>(Op1))
if (Instruction *Res = FoldShiftByConstant(Op0, CUI, I)) if (Instruction *Res = FoldShiftByConstant(Op0, CUI, I))
return Res; return Res;
@ -309,20 +309,30 @@ static Value *GetShiftedValue(Value *V, unsigned NumBits, bool isLeftShift,
Instruction *InstCombiner::FoldShiftByConstant(Value *Op0, ConstantInt *Op1, Instruction *InstCombiner::FoldShiftByConstant(Value *Op0, Constant *Op1,
BinaryOperator &I) { BinaryOperator &I) {
bool isLeftShift = I.getOpcode() == Instruction::Shl; bool isLeftShift = I.getOpcode() == Instruction::Shl;
ConstantInt *COp1 = nullptr;
if (ConstantDataVector *CV = dyn_cast<ConstantDataVector>(Op1))
COp1 = dyn_cast_or_null<ConstantInt>(CV->getSplatValue());
else if (ConstantVector *CV = dyn_cast<ConstantVector>(Op1))
COp1 = dyn_cast_or_null<ConstantInt>(CV->getSplatValue());
else
COp1 = dyn_cast<ConstantInt>(Op1);
if (!COp1)
return nullptr;
// See if we can propagate this shift into the input, this covers the trivial // See if we can propagate this shift into the input, this covers the trivial
// cast of lshr(shl(x,c1),c2) as well as other more complex cases. // cast of lshr(shl(x,c1),c2) as well as other more complex cases.
if (I.getOpcode() != Instruction::AShr && if (I.getOpcode() != Instruction::AShr &&
CanEvaluateShifted(Op0, Op1->getZExtValue(), isLeftShift, *this)) { CanEvaluateShifted(Op0, COp1->getZExtValue(), isLeftShift, *this)) {
DEBUG(dbgs() << "ICE: GetShiftedValue propagating shift through expression" DEBUG(dbgs() << "ICE: GetShiftedValue propagating shift through expression"
" to eliminate shift:\n IN: " << *Op0 << "\n SH: " << I <<"\n"); " to eliminate shift:\n IN: " << *Op0 << "\n SH: " << I <<"\n");
return ReplaceInstUsesWith(I, return ReplaceInstUsesWith(I,
GetShiftedValue(Op0, Op1->getZExtValue(), isLeftShift, *this)); GetShiftedValue(Op0, COp1->getZExtValue(), isLeftShift, *this));
} }
@ -333,7 +343,7 @@ Instruction *InstCombiner::FoldShiftByConstant(Value *Op0, ConstantInt *Op1,
// shl i32 X, 32 = 0 and srl i8 Y, 9 = 0, ... just don't eliminate // shl i32 X, 32 = 0 and srl i8 Y, 9 = 0, ... just don't eliminate
// a signed shift. // a signed shift.
// //
if (Op1->uge(TypeBits)) { if (COp1->uge(TypeBits)) {
if (I.getOpcode() != Instruction::AShr) if (I.getOpcode() != Instruction::AShr)
return ReplaceInstUsesWith(I, Constant::getNullValue(Op0->getType())); return ReplaceInstUsesWith(I, Constant::getNullValue(Op0->getType()));
// ashr i32 X, 32 --> ashr i32 X, 31 // ashr i32 X, 32 --> ashr i32 X, 31
@ -367,7 +377,7 @@ Instruction *InstCombiner::FoldShiftByConstant(Value *Op0, ConstantInt *Op1,
if (TrOp && I.isLogicalShift() && TrOp->isShift() && if (TrOp && I.isLogicalShift() && TrOp->isShift() &&
isa<ConstantInt>(TrOp->getOperand(1))) { isa<ConstantInt>(TrOp->getOperand(1))) {
// Okay, we'll do this xform. Make the shift of shift. // Okay, we'll do this xform. Make the shift of shift.
Constant *ShAmt = ConstantExpr::getZExt(Op1, TrOp->getType()); Constant *ShAmt = ConstantExpr::getZExt(COp1, TrOp->getType());
// (shift2 (shift1 & 0x00FF), c2) // (shift2 (shift1 & 0x00FF), c2)
Value *NSh = Builder->CreateBinOp(I.getOpcode(), TrOp, ShAmt,I.getName()); Value *NSh = Builder->CreateBinOp(I.getOpcode(), TrOp, ShAmt,I.getName());
@ -384,10 +394,10 @@ Instruction *InstCombiner::FoldShiftByConstant(Value *Op0, ConstantInt *Op1,
// shift. We know that it is a logical shift by a constant, so adjust the // shift. We know that it is a logical shift by a constant, so adjust the
// mask as appropriate. // mask as appropriate.
if (I.getOpcode() == Instruction::Shl) if (I.getOpcode() == Instruction::Shl)
MaskV <<= Op1->getZExtValue(); MaskV <<= COp1->getZExtValue();
else { else {
assert(I.getOpcode() == Instruction::LShr && "Unknown logical shift"); assert(I.getOpcode() == Instruction::LShr && "Unknown logical shift");
MaskV = MaskV.lshr(Op1->getZExtValue()); MaskV = MaskV.lshr(COp1->getZExtValue());
} }
// shift1 & 0x00FF // shift1 & 0x00FF
@ -421,9 +431,13 @@ Instruction *InstCombiner::FoldShiftByConstant(Value *Op0, ConstantInt *Op1,
// (X + (Y << C)) // (X + (Y << C))
Value *X = Builder->CreateBinOp(Op0BO->getOpcode(), YS, V1, Value *X = Builder->CreateBinOp(Op0BO->getOpcode(), YS, V1,
Op0BO->getOperand(1)->getName()); Op0BO->getOperand(1)->getName());
uint32_t Op1Val = Op1->getLimitedValue(TypeBits); uint32_t Op1Val = COp1->getLimitedValue(TypeBits);
return BinaryOperator::CreateAnd(X, ConstantInt::get(I.getContext(),
APInt::getHighBitsSet(TypeBits, TypeBits-Op1Val))); APInt Bits = APInt::getHighBitsSet(TypeBits, TypeBits - Op1Val);
Constant *Mask = ConstantInt::get(I.getContext(), Bits);
if (VectorType *VT = dyn_cast<VectorType>(X->getType()))
Mask = ConstantVector::getSplat(VT->getNumElements(), Mask);
return BinaryOperator::CreateAnd(X, Mask);
} }
// Turn (Y + ((X >> C) & CC)) << C -> ((X & (CC << C)) + (Y << C)) // Turn (Y + ((X >> C) & CC)) << C -> ((X & (CC << C)) + (Y << C))
@ -453,9 +467,13 @@ Instruction *InstCombiner::FoldShiftByConstant(Value *Op0, ConstantInt *Op1,
// (X + (Y << C)) // (X + (Y << C))
Value *X = Builder->CreateBinOp(Op0BO->getOpcode(), V1, YS, Value *X = Builder->CreateBinOp(Op0BO->getOpcode(), V1, YS,
Op0BO->getOperand(0)->getName()); Op0BO->getOperand(0)->getName());
uint32_t Op1Val = Op1->getLimitedValue(TypeBits); uint32_t Op1Val = COp1->getLimitedValue(TypeBits);
return BinaryOperator::CreateAnd(X, ConstantInt::get(I.getContext(),
APInt::getHighBitsSet(TypeBits, TypeBits-Op1Val))); APInt Bits = APInt::getHighBitsSet(TypeBits, TypeBits - Op1Val);
Constant *Mask = ConstantInt::get(I.getContext(), Bits);
if (VectorType *VT = dyn_cast<VectorType>(X->getType()))
Mask = ConstantVector::getSplat(VT->getNumElements(), Mask);
return BinaryOperator::CreateAnd(X, Mask);
} }
// Turn (((X >> C)&CC) + Y) << C -> (X + (Y << C)) & (CC << C) // Turn (((X >> C)&CC) + Y) << C -> (X + (Y << C)) & (CC << C)
@ -541,7 +559,7 @@ Instruction *InstCombiner::FoldShiftByConstant(Value *Op0, ConstantInt *Op1,
ConstantInt *ShiftAmt1C = cast<ConstantInt>(ShiftOp->getOperand(1)); ConstantInt *ShiftAmt1C = cast<ConstantInt>(ShiftOp->getOperand(1));
uint32_t ShiftAmt1 = ShiftAmt1C->getLimitedValue(TypeBits); uint32_t ShiftAmt1 = ShiftAmt1C->getLimitedValue(TypeBits);
uint32_t ShiftAmt2 = Op1->getLimitedValue(TypeBits); uint32_t ShiftAmt2 = COp1->getLimitedValue(TypeBits);
assert(ShiftAmt2 != 0 && "Should have been simplified earlier"); assert(ShiftAmt2 != 0 && "Should have been simplified earlier");
if (ShiftAmt1 == 0) return 0; // Will be simplified in the future. if (ShiftAmt1 == 0) return 0; // Will be simplified in the future.
Value *X = ShiftOp->getOperand(0); Value *X = ShiftOp->getOperand(0);

View File

@ -0,0 +1,67 @@
; RUN: opt -S -instcombine < %s | FileCheck %s
; CHECK-LABEL: @test_FoldShiftByConstant_CreateSHL
; CHECK: mul <4 x i32> %in, <i32 0, i32 -32, i32 0, i32 -32>
; CHECK-NEXT: ret
define <4 x i32> @test_FoldShiftByConstant_CreateSHL(<4 x i32> %in) {
%mul.i = mul <4 x i32> %in, <i32 0, i32 -1, i32 0, i32 -1>
%vshl_n = shl <4 x i32> %mul.i, <i32 5, i32 5, i32 5, i32 5>
ret <4 x i32> %vshl_n
}
; CHECK-LABEL: @test_FoldShiftByConstant_CreateSHL2
; CHECK: mul <8 x i16> %in, <i16 0, i16 -32, i16 0, i16 -32, i16 0, i16 -32, i16 0, i16 -32>
; CHECK-NEXT: ret
define <8 x i16> @test_FoldShiftByConstant_CreateSHL2(<8 x i16> %in) {
%mul.i = mul <8 x i16> %in, <i16 0, i16 -1, i16 0, i16 -1, i16 0, i16 -1, i16 0, i16 -1>
%vshl_n = shl <8 x i16> %mul.i, <i16 5, i16 5, i16 5, i16 5, i16 5, i16 5, i16 5, i16 5>
ret <8 x i16> %vshl_n
}
; CHECK-LABEL: @test_FoldShiftByConstant_CreateAnd
; CHECK: mul <16 x i8> %in0, <i8 33, i8 33, i8 33, i8 33, i8 33, i8 33, i8 33, i8 33, i8 33, i8 33, i8 33, i8 33, i8 33, i8 33, i8 33, i8 33>
; CHECK-NEXT: and <16 x i8> %vsra_n2, <i8 -32, i8 -32, i8 -32, i8 -32, i8 -32, i8 -32, i8 -32, i8 -32, i8 -32, i8 -32, i8 -32, i8 -32, i8 -32, i8 -32, i8 -32, i8 -32>
; CHECK-NEXT: ret
define <16 x i8> @test_FoldShiftByConstant_CreateAnd(<16 x i8> %in0) {
%vsra_n = ashr <16 x i8> %in0, <i8 5, i8 5, i8 5, i8 5, i8 5, i8 5, i8 5, i8 5, i8 5, i8 5, i8 5, i8 5, i8 5, i8 5, i8 5, i8 5>
%tmp = add <16 x i8> %in0, %vsra_n
%vshl_n = shl <16 x i8> %tmp, <i8 5, i8 5, i8 5, i8 5, i8 5, i8 5, i8 5, i8 5, i8 5, i8 5, i8 5, i8 5, i8 5, i8 5, i8 5, i8 5>
ret <16 x i8> %vshl_n
}
define i32 @bar(i32 %x, i32 %y) {
%a = lshr i32 %x, 4
%b = add i32 %a, %y
%c = shl i32 %b, 4
ret i32 %c
}
define <2 x i32> @bar_v2i32(<2 x i32> %x, <2 x i32> %y) {
%a = lshr <2 x i32> %x, <i32 5, i32 5>
%b = add <2 x i32> %a, %y
%c = shl <2 x i32> %b, <i32 5, i32 5>
ret <2 x i32> %c
}
define i32 @foo(i32 %x, i32 %y) {
%a = lshr i32 %x, 4
%b = and i32 %a, 8
%c = add i32 %b, %y
%d = shl i32 %c, 4
ret i32 %d
}
define <2 x i32> @foo_v2i32(<2 x i32> %x, <2 x i32> %y) {
%a = lshr <2 x i32> %x, <i32 4, i32 4>
%b = and <2 x i32> %a, <i32 8, i32 8>
%c = add <2 x i32> %b, %y
%d = shl <2 x i32> %c, <i32 4, i32 4>
ret <2 x i32> %d
}

View File

@ -40,6 +40,27 @@ define i32 @test5(i32 %A) {
ret i32 %B ret i32 %B
} }
define <4 x i32> @test5_splat_vector(<4 x i32> %A) {
; CHECK-LABEL: @test5_splat_vector(
; CHECK: ret <4 x i32> undef
%B = lshr <4 x i32> %A, <i32 32, i32 32, i32 32, i32 32> ;; shift all bits out
ret <4 x i32> %B
}
define <4 x i32> @test5_zero_vector(<4 x i32> %A) {
; CHECK-LABEL: @test5_zero_vector(
; CHECK-NEXT: ret <4 x i32> %A
%B = lshr <4 x i32> %A, zeroinitializer
ret <4 x i32> %B
}
define <4 x i32> @test5_non_splat_vector(<4 x i32> %A) {
; CHECK-LABEL: @test5_non_splat_vector(
; CHECK-NOT: ret <4 x i32> undef
%B = shl <4 x i32> %A, <i32 32, i32 1, i32 2, i32 3>
ret <4 x i32> %B
}
define i32 @test5a(i32 %A) { define i32 @test5a(i32 %A) {
; CHECK-LABEL: @test5a( ; CHECK-LABEL: @test5a(
; CHECK: ret i32 undef ; CHECK: ret i32 undef
@ -47,6 +68,20 @@ define i32 @test5a(i32 %A) {
ret i32 %B ret i32 %B
} }
define <4 x i32> @test5a_splat_vector(<4 x i32> %A) {
; CHECK-LABEL: @test5a_splat_vector(
; CHECK: ret <4 x i32> undef
%B = shl <4 x i32> %A, <i32 32, i32 32, i32 32, i32 32> ;; shift all bits out
ret <4 x i32> %B
}
define <4 x i32> @test5a_non_splat_vector(<4 x i32> %A) {
; CHECK-LABEL: @test5a_non_splat_vector(
; CHECK-NOT: ret <4 x i32> undef
%B = shl <4 x i32> %A, <i32 32, i32 1, i32 2, i32 3>
ret <4 x i32> %B
}
define i32 @test5b() { define i32 @test5b() {
; CHECK-LABEL: @test5b( ; CHECK-LABEL: @test5b(
; CHECK: ret i32 -1 ; CHECK: ret i32 -1
@ -344,6 +379,20 @@ define i32 @test25(i32 %tmp.2, i32 %AA) {
ret i32 %tmp.6 ret i32 %tmp.6
} }
define <2 x i32> @test25_vector(<2 x i32> %tmp.2, <2 x i32> %AA) {
; CHECK-LABEL: @test25_vector(
; CHECK: %tmp.3 = lshr <2 x i32> %tmp.2, <i32 17, i32 17>
; CHECK-NEXT: shl <2 x i32> %tmp.3, <i32 17, i32 17>
; CHECK-NEXT: add <2 x i32> %tmp.51, %AA
; CHECK-NEXT: and <2 x i32> %x2, <i32 -131072, i32 -131072>
; CHECK-NEXT: ret <2 x i32>
%x = lshr <2 x i32> %AA, <i32 17, i32 17>
%tmp.3 = lshr <2 x i32> %tmp.2, <i32 17, i32 17>
%tmp.5 = add <2 x i32> %tmp.3, %x
%tmp.6 = shl <2 x i32> %tmp.5, <i32 17, i32 17>
ret <2 x i32> %tmp.6
}
;; handle casts between shifts. ;; handle casts between shifts.
define i32 @test26(i32 %A) { define i32 @test26(i32 %A) {
; CHECK-LABEL: @test26( ; CHECK-LABEL: @test26(
@ -780,3 +829,32 @@ bb11: ; preds = %bb8
bb12: ; preds = %bb11, %bb8, %bb bb12: ; preds = %bb11, %bb8, %bb
ret void ret void
} }
define i32 @test64(i32 %a) {
; CHECK-LABEL: @test64(
; CHECK-NEXT: ret i32 undef
%b = ashr i32 %a, 32 ; shift all bits out
ret i32 %b
}
define <4 x i32> @test64_splat_vector(<4 x i32> %a) {
; CHECK-LABEL: @test64_splat_vector
; CHECK-NEXT: ret <4 x i32> undef
%b = ashr <4 x i32> %a, <i32 32, i32 32, i32 32, i32 32> ; shift all bits out
ret <4 x i32> %b
}
define <4 x i32> @test64_non_splat_vector(<4 x i32> %a) {
; CHECK-LABEL: @test64_non_splat_vector
; CHECK-NOT: ret <4 x i32> undef
%b = ashr <4 x i32> %a, <i32 32, i32 0, i32 1, i32 2> ; shift all bits out
ret <4 x i32> %b
}
define <2 x i65> @test_65(<2 x i64> %t) {
; CHECK-LABEL: @test_65
%a = zext <2 x i64> %t to <2 x i65>
%sext = shl <2 x i65> %a, <i65 33, i65 33>
%b = ashr <2 x i65> %sext, <i65 33, i65 33>
ret <2 x i65> %b
}