Summary of this patch (leading text before the first "diff --git" header):
  * Adds the TargetTransformInfo hook getCostOfKeepingLiveOverCall(), which
    forwards to the previous TTI; the NoTTI fallback returns 0.
  * AArch64 overrides it: each 128-bit vector type is charged the cost of one
    store plus one load.
  * The SLP vectorizer gains BoUpSLP::getSpillCost(), which walks the
    vectorizable tree, queries the hook for every call found between tree
    entries, and adds the result to getTreeCost().
  * Adds an AArch64 regression test (load-store-q.ll).

diff --git a/include/llvm/Analysis/TargetTransformInfo.h b/include/llvm/Analysis/TargetTransformInfo.h
index f57f3eb009a..0ec49c2b18a 100644
--- a/include/llvm/Analysis/TargetTransformInfo.h
+++ b/include/llvm/Analysis/TargetTransformInfo.h
@@ -416,6 +416,13 @@ public:
   virtual unsigned getAddressComputationCost(Type *Ty,
                                              bool IsComplex = false) const;
 
+  /// \returns The cost, if any, of keeping values of the given types alive
+  /// over a callsite.
+  ///
+  /// Some types may require the use of register classes that do not have
+  /// any callee-saved registers, so would require a spill and fill.
+  virtual unsigned getCostOfKeepingLiveOverCall(ArrayRef<Type*> Tys) const;
+
   /// @}
 
   /// Analysis group identification.
diff --git a/lib/Analysis/TargetTransformInfo.cpp b/lib/Analysis/TargetTransformInfo.cpp
index 888b5cef2f8..7ac22303deb 100644
--- a/lib/Analysis/TargetTransformInfo.cpp
+++ b/lib/Analysis/TargetTransformInfo.cpp
@@ -230,6 +230,11 @@ unsigned TargetTransformInfo::getReductionCost(unsigned Opcode, Type *Ty,
   return PrevTTI->getReductionCost(Opcode, Ty, IsPairwise);
 }
 
+unsigned TargetTransformInfo::getCostOfKeepingLiveOverCall(ArrayRef<Type*> Tys)
+  const {
+  return PrevTTI->getCostOfKeepingLiveOverCall(Tys);
+}
+
 namespace {
 
 struct NoTTI final : ImmutablePass, TargetTransformInfo {
@@ -613,6 +618,11 @@ struct NoTTI final : ImmutablePass, TargetTransformInfo {
   unsigned getReductionCost(unsigned, Type *, bool) const override {
     return 1;
   }
+
+  unsigned getCostOfKeepingLiveOverCall(ArrayRef<Type*> Tys) const override {
+    return 0;
+  }
+
 };
 
 } // end anonymous namespace
diff --git a/lib/Target/AArch64/AArch64TargetTransformInfo.cpp b/lib/Target/AArch64/AArch64TargetTransformInfo.cpp
index b1c931e96e8..2058dd06b21 100644
--- a/lib/Target/AArch64/AArch64TargetTransformInfo.cpp
+++ b/lib/Target/AArch64/AArch64TargetTransformInfo.cpp
@@ -124,6 +124,9 @@ public:
 
   unsigned getMemoryOpCost(unsigned Opcode, Type *Src, unsigned Alignment,
                            unsigned AddressSpace) const override;
+
+  unsigned getCostOfKeepingLiveOverCall(ArrayRef<Type*> Tys) const override;
+
   /// @}
 };
 
@@ -498,3 +501,15 @@ unsigned AArch64TTI::getMemoryOpCost(unsigned Opcode, Type *Src,
 
   return LT.first;
 }
+
+unsigned AArch64TTI::getCostOfKeepingLiveOverCall(ArrayRef<Type*> Tys) const {
+  unsigned Cost = 0;
+  for (auto *I : Tys) {
+    if (!I->isVectorTy())
+      continue;
+    if (I->getScalarSizeInBits() * I->getVectorNumElements() == 128)
+      Cost += getMemoryOpCost(Instruction::Store, I, 128, 0) +
+        getMemoryOpCost(Instruction::Load, I, 128, 0);
+  }
+  return Cost;
+}
diff --git a/lib/Transforms/Vectorize/SLPVectorizer.cpp b/lib/Transforms/Vectorize/SLPVectorizer.cpp
index c91ca280033..d73e746d1ea 100644
--- a/lib/Transforms/Vectorize/SLPVectorizer.cpp
+++ b/lib/Transforms/Vectorize/SLPVectorizer.cpp
@@ -361,6 +361,10 @@ public:
   /// Returns the vectorized root.
   Value *vectorizeTree();
 
+  /// \returns the cost incurred by unwanted spills and fills, caused by
+  /// holding live values over call sites.
+  int getSpillCost();
+
   /// \returns the vectorization cost of the subtree that starts at \p VL.
   /// A negative number means that this is profitable.
   int getTreeCost();
@@ -1543,6 +1547,68 @@ bool BoUpSLP::isFullyVectorizableTinyTree() {
   return true;
 }
 
+int BoUpSLP::getSpillCost() {
+  // Walk from the bottom of the tree to the top, tracking which values are
+  // live. When we see a call instruction that is not part of our tree,
+  // query TTI to see if there is a cost to keeping values live over it
+  // (for example, if spills and fills are required).
+  unsigned BundleWidth = VectorizableTree.front().Scalars.size();
+  int Cost = 0;
+
+  SmallPtrSet<Instruction*, 4> LiveValues;
+  Instruction *PrevInst = nullptr;
+
+  for (unsigned N = 0; N < VectorizableTree.size(); ++N) {
+    Instruction *Inst = dyn_cast<Instruction>(VectorizableTree[N].Scalars[0]);
+    if (!Inst)
+      continue;
+
+    if (!PrevInst) {
+      PrevInst = Inst;
+      continue;
+    }
+
+    DEBUG(
+      dbgs() << "SLP: #LV: " << LiveValues.size();
+      for (auto *X : LiveValues)
+        dbgs() << " " << X->getName();
+      dbgs() << ", Looking at ";
+      Inst->dump();
+      );
+
+    // Update LiveValues.
+    LiveValues.erase(PrevInst);
+    for (auto &J : PrevInst->operands()) {
+      if (isa<Instruction>(&*J) && ScalarToTreeEntry.count(&*J))
+        LiveValues.insert(cast<Instruction>(&*J));
+    }
+
+    // Now find the sequence of instructions between PrevInst and Inst.
+    BasicBlock::reverse_iterator InstIt(Inst), PrevInstIt(PrevInst);
+    --PrevInstIt;
+    while (InstIt != PrevInstIt) {
+      if (PrevInstIt == PrevInst->getParent()->rend()) {
+        PrevInstIt = Inst->getParent()->rbegin();
+        continue;
+      }
+
+      if (isa<CallInst>(&*PrevInstIt) && &*PrevInstIt != PrevInst) {
+        SmallVector<Type*, 4> V;
+        for (auto *II : LiveValues)
+          V.push_back(VectorType::get(II->getType(), BundleWidth));
+        Cost += TTI->getCostOfKeepingLiveOverCall(V);
+      }
+
+      ++PrevInstIt;
+    }
+
+    PrevInst = Inst;
+  }
+
+  DEBUG(dbgs() << "SLP: SpillCost=" << Cost << "\n");
+  return Cost;
+}
+
 int BoUpSLP::getTreeCost() {
   int Cost = 0;
   DEBUG(dbgs() << "SLP: Calculating cost for tree of size " <<
@@ -1578,6 +1644,8 @@ int BoUpSLP::getTreeCost() {
                                            I->Lane);
   }
 
+  Cost += getSpillCost();
+
   DEBUG(dbgs() << "SLP: Total Cost " << Cost + ExtractCost<< ".\n");
   return Cost + ExtractCost;
 }
diff --git a/test/Transforms/SLPVectorizer/AArch64/load-store-q.ll b/test/Transforms/SLPVectorizer/AArch64/load-store-q.ll
new file mode 100644
index 00000000000..45fa2f917f3
--- /dev/null
+++ b/test/Transforms/SLPVectorizer/AArch64/load-store-q.ll
@@ -0,0 +1,46 @@
+; RUN: opt -S -basicaa -slp-vectorizer < %s | FileCheck %s
+target datalayout = "e-m:o-i64:64-i128:128-n32:64-S128"
+target triple = "arm64-apple-ios5.0.0"
+
+; Holding a value live over a call boundary may require
+; spills and fills. This is the case for <2 x double>,
+; as it occupies a Q register of which there are no
+; callee-saves.
+
+; CHECK: load double
+; CHECK: load double
+; CHECK: call void @g
+; CHECK: store double
+; CHECK: store double
+define void @f(double* %p, double* %q) {
+  %addr2 = getelementptr double* %q, i32 1
+  %addr = getelementptr double* %p, i32 1
+  %x = load double* %p
+  %y = load double* %addr
+  call void @g()
+  store double %x, double* %q
+  store double %y, double* %addr2
+  ret void
+}
+declare void @g()
+
+; Check we deal with loops correctly.
+;
+; CHECK: store <2 x double>
+; CHECK: load <2 x double>
+define void @f2(double* %p, double* %q) {
+entry:
+  br label %loop
+
+loop:
+  %p1 = phi double [0.0, %entry], [%x, %loop]
+  %p2 = phi double [0.0, %entry], [%y, %loop]
+  %addr2 = getelementptr double* %q, i32 1
+  %addr = getelementptr double* %p, i32 1
+  store double %p1, double* %q
+  store double %p2, double* %addr2
+
+  %x = load double* %p
+  %y = load double* %addr
+  br label %loop
+}