From 46b13dd8803e0004d11d66cf73f6f24b6d211971 Mon Sep 17 00:00:00 2001 From: David Majnemer Date: Mon, 13 Jul 2015 01:15:53 +0000 Subject: [PATCH] [InstSimplify] Teach InstSimplify how to simplify extractelement git-svn-id: https://llvm.org/svn/llvm-project/llvm/trunk@242008 91177308-0d34-0410-b5e6-96231b3b80d8 --- include/llvm/Analysis/ConstantFolding.h | 5 ++ include/llvm/Analysis/InstructionSimplify.h | 9 +++ include/llvm/Analysis/VectorUtils.h | 5 ++ lib/Analysis/InstructionSimplify.cpp | 48 +++++++++++++ lib/Analysis/VectorUtils.cpp | 52 ++++++++++++++ .../InstCombine/InstCombineVectorOps.cpp | 67 +++---------------- test/Transforms/InstSimplify/undef.ll | 14 ++++ 7 files changed, 142 insertions(+), 58 deletions(-) diff --git a/include/llvm/Analysis/ConstantFolding.h b/include/llvm/Analysis/ConstantFolding.h index e67e2651c89..e8185b3b630 100644 --- a/include/llvm/Analysis/ConstantFolding.h +++ b/include/llvm/Analysis/ConstantFolding.h @@ -78,6 +78,11 @@ Constant *ConstantFoldInsertValueInstruction(Constant *Agg, Constant *Val, Constant *ConstantFoldExtractValueInstruction(Constant *Agg, ArrayRef Idxs); +/// \brief Attempt to constant fold an extractelement instruction with the +/// specified operands and indices. The constant result is returned if +/// successful; if not, null is returned. +Constant *ConstantFoldExtractElementInstruction(Constant *Val, Constant *Idx); + /// ConstantFoldLoadFromConstPtr - Return the value that a load from C would /// produce if it is constant and determinable. If this is not determinable, /// return null. diff --git a/include/llvm/Analysis/InstructionSimplify.h b/include/llvm/Analysis/InstructionSimplify.h index 0ed379062bb..d44c5ff4078 100644 --- a/include/llvm/Analysis/InstructionSimplify.h +++ b/include/llvm/Analysis/InstructionSimplify.h @@ -253,6 +253,15 @@ namespace llvm { AssumptionCache *AC = nullptr, const Instruction *CxtI = nullptr); + /// \brief Given operands for an ExtractElementInst, see if we can fold the + /// result. If not, this returns null. + Value *SimplifyExtractElementInst(Value *Vec, Value *Idx, + const DataLayout &DL, + const TargetLibraryInfo *TLI = nullptr, + const DominatorTree *DT = nullptr, + AssumptionCache *AC = nullptr, + const Instruction *CxtI = nullptr); + /// SimplifyTruncInst - Given operands for an TruncInst, see if we can fold /// the result. If not, this returns null. Value *SimplifyTruncInst(Value *Op, Type *Ty, const DataLayout &DL, diff --git a/include/llvm/Analysis/VectorUtils.h b/include/llvm/Analysis/VectorUtils.h index e53c7c0a9d9..d8e9ca42e62 100644 --- a/include/llvm/Analysis/VectorUtils.h +++ b/include/llvm/Analysis/VectorUtils.h @@ -74,6 +74,11 @@ Value *getUniqueCastUse(Value *Ptr, Loop *Lp, Type *Ty); /// strides "a[i*stride]". Returns the symbolic stride, or null otherwise. Value *getStrideFromPointer(Value *Ptr, ScalarEvolution *SE, Loop *Lp); +/// \brief Given a vector and an element number, see if the scalar value is +/// already around as a register, for example if it were inserted then extracted +/// from the vector. +Value *findScalarElement(Value *V, unsigned EltNo); + } // llvm namespace #endif diff --git a/lib/Analysis/InstructionSimplify.cpp b/lib/Analysis/InstructionSimplify.cpp index ad85a3ed73c..fa42b48b6cd 100644 --- a/lib/Analysis/InstructionSimplify.cpp +++ b/lib/Analysis/InstructionSimplify.cpp @@ -24,6 +24,7 @@ #include "llvm/Analysis/ConstantFolding.h" #include "llvm/Analysis/MemoryBuiltins.h" #include "llvm/Analysis/ValueTracking.h" +#include "llvm/Analysis/VectorUtils.h" #include "llvm/IR/ConstantRange.h" #include "llvm/IR/DataLayout.h" #include "llvm/IR/Dominators.h" @@ -3555,6 +3556,47 @@ Value *llvm::SimplifyExtractValueInst(Value *Agg, ArrayRef Idxs, RecursionLimit); } +/// SimplifyExtractElementInst - Given operands for an ExtractElementInst, see if we +/// can fold the result. If not, this returns null. +static Value *SimplifyExtractElementInst(Value *Vec, Value *Idx, const Query &, + unsigned) { + if (auto *CVec = dyn_cast(Vec)) { + if (auto *CIdx = dyn_cast(Idx)) + return ConstantFoldExtractElementInstruction(CVec, CIdx); + + // The index is not relevant if our vector is a splat. + if (auto *Splat = CVec->getSplatValue()) + return Splat; + + if (isa(Vec)) + return UndefValue::get(Vec->getType()->getVectorElementType()); + } + + // If extracting a specified index from the vector, see if we can recursively + // find a previously computed scalar that was inserted into the vector. + if (auto *IdxC = dyn_cast(Idx)) { + unsigned IndexVal = IdxC->getZExtValue(); + unsigned VectorWidth = Vec->getType()->getVectorNumElements(); + + // If this is extracting an invalid index, turn this into undef, to avoid + // crashing the code below. + if (IndexVal >= VectorWidth) + return UndefValue::get(Vec->getType()->getVectorElementType()); + + if (Value *Elt = findScalarElement(Vec, IndexVal)) + return Elt; + } + + return nullptr; +} + +Value *llvm::SimplifyExtractElementInst( + Value *Vec, Value *Idx, const DataLayout &DL, const TargetLibraryInfo *TLI, + const DominatorTree *DT, AssumptionCache *AC, const Instruction *CxtI) { + return ::SimplifyExtractElementInst(Vec, Idx, Query(DL, TLI, DT, AC, CxtI), + RecursionLimit); +} + /// SimplifyPHINode - See if we can fold the given phi. If not, returns null. static Value *SimplifyPHINode(PHINode *PN, const Query &Q) { // If all of the PHI's incoming values are the same then replace the PHI node @@ -3970,6 +4012,12 @@ Value *llvm::SimplifyInstruction(Instruction *I, const DataLayout &DL, EVI->getIndices(), DL, TLI, DT, AC, I); break; } + case Instruction::ExtractElement: { + auto *EEI = cast(I); + Result = SimplifyExtractElementInst( + EEI->getVectorOperand(), EEI->getIndexOperand(), DL, TLI, DT, AC, I); + break; + } case Instruction::PHI: Result = SimplifyPHINode(cast(I), Query(DL, TLI, DT, AC, I)); break; diff --git a/lib/Analysis/VectorUtils.cpp b/lib/Analysis/VectorUtils.cpp index eab5887a17e..67f68dc8391 100644 --- a/lib/Analysis/VectorUtils.cpp +++ b/lib/Analysis/VectorUtils.cpp @@ -357,3 +357,55 @@ llvm::Value *llvm::getStrideFromPointer(llvm::Value *Ptr, ScalarEvolution *SE, return Stride; } + +/// \brief Given a vector and an element number, see if the scalar value is +/// already around as a register, for example if it were inserted then extracted +/// from the vector. +llvm::Value *llvm::findScalarElement(llvm::Value *V, unsigned EltNo) { + assert(V->getType()->isVectorTy() && "Not looking at a vector?"); + VectorType *VTy = cast(V->getType()); + unsigned Width = VTy->getNumElements(); + if (EltNo >= Width) // Out of range access. + return UndefValue::get(VTy->getElementType()); + + if (Constant *C = dyn_cast(V)) + return C->getAggregateElement(EltNo); + + if (InsertElementInst *III = dyn_cast(V)) { + // If this is an insert to a variable element, we don't know what it is. + if (!isa(III->getOperand(2))) + return nullptr; + unsigned IIElt = cast(III->getOperand(2))->getZExtValue(); + + // If this is an insert to the element we are looking for, return the + // inserted value. + if (EltNo == IIElt) + return III->getOperand(1); + + // Otherwise, the insertelement doesn't modify the value, recurse on its + // vector input. + return findScalarElement(III->getOperand(0), EltNo); + } + + if (ShuffleVectorInst *SVI = dyn_cast(V)) { + unsigned LHSWidth = SVI->getOperand(0)->getType()->getVectorNumElements(); + int InEl = SVI->getMaskValue(EltNo); + if (InEl < 0) + return UndefValue::get(VTy->getElementType()); + if (InEl < (int)LHSWidth) + return findScalarElement(SVI->getOperand(0), InEl); + return findScalarElement(SVI->getOperand(1), InEl - LHSWidth); + } + + // Extract a value from a vector add operation with a constant zero. + Value *Val = nullptr; Constant *Con = nullptr; + if (match(V, + llvm::PatternMatch::m_Add(llvm::PatternMatch::m_Value(Val), + llvm::PatternMatch::m_Constant(Con)))) { + if (Con->getAggregateElement(EltNo)->isNullValue()) + return findScalarElement(Val, EltNo); + } + + // Otherwise, we don't know. + return nullptr; +} diff --git a/lib/Transforms/InstCombine/InstCombineVectorOps.cpp b/lib/Transforms/InstCombine/InstCombineVectorOps.cpp index 24446c8578e..273047279e9 100644 --- a/lib/Transforms/InstCombine/InstCombineVectorOps.cpp +++ b/lib/Transforms/InstCombine/InstCombineVectorOps.cpp @@ -14,6 +14,8 @@ #include "InstCombineInternal.h" #include "llvm/ADT/DenseMap.h" +#include "llvm/Analysis/InstructionSimplify.h" +#include "llvm/Analysis/VectorUtils.h" #include "llvm/IR/PatternMatch.h" using namespace llvm; using namespace PatternMatch; @@ -60,56 +62,6 @@ static bool CheapToScalarize(Value *V, bool isConstant) { return false; } -/// FindScalarElement - Given a vector and an element number, see if the scalar -/// value is already around as a register, for example if it were inserted then -/// extracted from the vector. -static Value *FindScalarElement(Value *V, unsigned EltNo) { - assert(V->getType()->isVectorTy() && "Not looking at a vector?"); - VectorType *VTy = cast(V->getType()); - unsigned Width = VTy->getNumElements(); - if (EltNo >= Width) // Out of range access. - return UndefValue::get(VTy->getElementType()); - - if (Constant *C = dyn_cast(V)) - return C->getAggregateElement(EltNo); - - if (InsertElementInst *III = dyn_cast(V)) { - // If this is an insert to a variable element, we don't know what it is. - if (!isa(III->getOperand(2))) - return nullptr; - unsigned IIElt = cast(III->getOperand(2))->getZExtValue(); - - // If this is an insert to the element we are looking for, return the - // inserted value. - if (EltNo == IIElt) - return III->getOperand(1); - - // Otherwise, the insertelement doesn't modify the value, recurse on its - // vector input. - return FindScalarElement(III->getOperand(0), EltNo); - } - - if (ShuffleVectorInst *SVI = dyn_cast(V)) { - unsigned LHSWidth = SVI->getOperand(0)->getType()->getVectorNumElements(); - int InEl = SVI->getMaskValue(EltNo); - if (InEl < 0) - return UndefValue::get(VTy->getElementType()); - if (InEl < (int)LHSWidth) - return FindScalarElement(SVI->getOperand(0), InEl); - return FindScalarElement(SVI->getOperand(1), InEl - LHSWidth); - } - - // Extract a value from a vector add operation with a constant zero. - Value *Val = nullptr; Constant *Con = nullptr; - if (match(V, m_Add(m_Value(Val), m_Constant(Con)))) { - if (Con->getAggregateElement(EltNo)->isNullValue()) - return FindScalarElement(Val, EltNo); - } - - // Otherwise, we don't know. - return nullptr; -} - // If we have a PHI node with a vector type that has only 2 uses: feed // itself and be an operand of extractelement at a constant location, // try to replace the PHI of the vector type with a PHI of a scalar type. @@ -178,6 +130,10 @@ Instruction *InstCombiner::scalarizePHI(ExtractElementInst &EI, PHINode *PN) { } Instruction *InstCombiner::visitExtractElementInst(ExtractElementInst &EI) { + if (Value *V = SimplifyExtractElementInst( + EI.getVectorOperand(), EI.getIndexOperand(), DL, TLI, DT, AC)) + return ReplaceInstUsesWith(EI, V); + // If vector val is constant with all elements the same, replace EI with // that element. We handle a known element # below. if (Constant *C = dyn_cast(EI.getOperand(0))) @@ -190,10 +146,8 @@ Instruction *InstCombiner::visitExtractElementInst(ExtractElementInst &EI) { unsigned IndexVal = IdxC->getZExtValue(); unsigned VectorWidth = EI.getVectorOperandType()->getNumElements(); - // If this is extracting an invalid index, turn this into undef, to avoid - // crashing the code below. - if (IndexVal >= VectorWidth) - return ReplaceInstUsesWith(EI, UndefValue::get(EI.getType())); + // InstSimplify handles cases where the index is invalid. + assert(IndexVal < VectorWidth); // This instruction only demands the single element from the input vector. // If the input vector has a single use, simplify it based on this use @@ -209,16 +163,13 @@ Instruction *InstCombiner::visitExtractElementInst(ExtractElementInst &EI) { } } - if (Value *Elt = FindScalarElement(EI.getOperand(0), IndexVal)) - return ReplaceInstUsesWith(EI, Elt); - // If the this extractelement is directly using a bitcast from a vector of // the same number of elements, see if we can find the source element from // it. In this case, we will end up needing to bitcast the scalars. if (BitCastInst *BCI = dyn_cast(EI.getOperand(0))) { if (VectorType *VT = dyn_cast(BCI->getOperand(0)->getType())) if (VT->getNumElements() == VectorWidth) - if (Value *Elt = FindScalarElement(BCI->getOperand(0), IndexVal)) + if (Value *Elt = findScalarElement(BCI->getOperand(0), IndexVal)) return new BitCastInst(Elt, EI.getType()); } diff --git a/test/Transforms/InstSimplify/undef.ll b/test/Transforms/InstSimplify/undef.ll index f1f0b037fdb..d75dc364243 100644 --- a/test/Transforms/InstSimplify/undef.ll +++ b/test/Transforms/InstSimplify/undef.ll @@ -265,3 +265,17 @@ define i32 @test34(i32 %a) { %b = lshr i32 undef, 0 ret i32 %b } + +; CHECK-LABEL: @test35 +; CHECK: ret i32 undef +define i32 @test35(<4 x i32> %V) { + %b = extractelement <4 x i32> %V, i32 4 + ret i32 %b +} + +; CHECK-LABEL: @test36 +; CHECK: ret i32 undef +define i32 @test36(i32 %V) { + %b = extractelement <4 x i32> undef, i32 %V + ret i32 %b +}