Preserve branch profile metadata during switch formation.

Patch by Michael Ilseman!
This fixes SimplifyCFGOpt::FoldValueComparisonIntoPredecessors to preserve metadata when folding conditional branches into switches.

void foo(int x) {
  if (x == 0)
    bar(1);
  else if (__builtin_expect(x == 10, 1))
    bar(2);
  else if (x == 20)
    bar(3);
}

CFG:

B0
|  \
|   X0
B10
|  \
|   X10
B20
|  \
E   X20

Merge B0-B10:
w(B0-X0) = w(B0-X0)*sum-weights(B10) = w(B0-X0) * (w(B10-X10) + w(B10-B20))
w(B0-X10) = w(B0-B10) * w(B10-X10)
w(B0-B20) = w(B0-B10) * w(B10-B20)

B0 __
| \  \
| X10 X0
B20
|  \
E  X20

Merge B0-B20:
w(B0-X0) = w(B0-X0) * sum-weights(B20) = w(B0-X0) * (w(B20-E) + w(B20-X20))
w(B0-X10) = w(B0-X10) * sum-weights(B20) = ...
w(B0-X20) = w(B0-B20) * w(B20-X20)
w(B0-E) = w(B0-B20) * w(B20-E)

git-svn-id: https://llvm.org/svn/llvm-project/llvm/trunk@162868 91177308-0d34-0410-b5e6-96231b3b80d8
This commit is contained in:
Andrew Trick 2012-08-29 21:46:38 +00:00
parent 6b01438dec
commit b1b97833ae

View File

@ -615,6 +615,9 @@ SimplifyEqualityComparisonWithOnlyPredecessor(TerminatorInst *TI,
assert(ThisVal && "This isn't a value comparison!!");
if (ThisVal != PredVal) return false; // Different predicates.
// TODO: Preserve branch weight metadata, similarly to how
// FoldValueComparisonIntoPredecessors preserves it.
// Find out information about when control will move from Pred to TI's block.
std::vector<ValueEqualityComparisonCase> PredCases;
BasicBlock *PredDef = GetValueEqualityComparisonCases(Pred->getTerminator(),
@ -738,6 +741,67 @@ static int ConstantIntSortPredicate(const void *P1, const void *P2) {
return -1;
}
static inline bool HasBranchWeights(const Instruction* I) {
MDNode* ProfMD = I->getMetadata(LLVMContext::MD_prof);
if (ProfMD && ProfMD->getOperand(0))
if (MDString* MDS = dyn_cast<MDString>(ProfMD->getOperand(0)))
return MDS->getString().equals("branch_weights");
return false;
}
/// Tries to get a branch weight for the given instruction, returns NULL if it
/// can't. Pos starts at 0.
///
/// NULL is returned when the instruction has no "branch_weights" MD_prof
/// metadata, when Pos does not name an existing weight operand, or when the
/// selected operand is not a ConstantInt.
static ConstantInt* GetWeight(Instruction* I, int Pos) {
  MDNode* ProfMD = I->getMetadata(LLVMContext::MD_prof);
  if (!ProfMD || !ProfMD->getOperand(0))
    return 0;
  MDString* MDS = dyn_cast<MDString>(ProfMD->getOperand(0));
  if (!MDS || !MDS->getString().equals("branch_weights"))
    return 0;
  assert(ProfMD->getNumOperands() >= 3);
  // Operand 0 is the "branch_weights" tag, so weight Pos lives at operand
  // 1 + Pos. Reject positions outside the node instead of reading past its
  // operand list when asserts are compiled out.
  if (Pos < 0 ||
      static_cast<unsigned>(Pos) + 1 >= ProfMD->getNumOperands())
    return 0;
  return dyn_cast<ConstantInt>(ProfMD->getOperand(1 + Pos));
}
/// Scale the given weights based on the new TI's metadata. Scaling is done by
/// multiplying every weight by the sum of the successor's weights.
///
/// STI must carry "branch_weights" metadata (asserted). Weights is updated in
/// place; its first entry (the default weight) is left untouched because the
/// caller replaces it during the folding.
static void ScaleWeights(Instruction* STI, MutableArrayRef<uint64_t> Weights) {
  // Sum the successor's weights
  assert(HasBranchWeights(STI));
  // Accumulate in 64 bits: getZExtValue() yields uint64_t, and a 32-bit
  // accumulator would silently truncate once the sum exceeds UINT_MAX.
  uint64_t Scale = 0;
  MDNode* ProfMD = STI->getMetadata(LLVMContext::MD_prof);
  // Operand 0 is the "branch_weights" tag; the weights start at operand 1.
  for (unsigned i = 1; i < ProfMD->getNumOperands(); ++i) {
    ConstantInt* CI = dyn_cast<ConstantInt>(ProfMD->getOperand(i));
    assert(CI);
    Scale += CI->getValue().getZExtValue();
  }

  // Skip default, as it's replaced during the folding
  for (unsigned i = 1; i < Weights.size(); ++i)
    Weights[i] *= Scale;
}
/// Sees if any of the weights are too big for a uint32_t, and if so scales all
/// the weights down by a common power of two until every one of them fits.
///
/// When a single halving is enough this is exactly one halving of every
/// weight. Larger values (e.g. products of two 32-bit weights) need more than
/// one halving, otherwise they would still be truncated when narrowed to
/// uint32_t. Weights is updated in place; nothing changes if all weights
/// already fit.
static void FitWeights(MutableArrayRef<uint64_t> Weights) {
  // Find the largest weight; it alone decides how far to scale down.
  uint64_t Max = 0;
  for (unsigned i = 0; i < Weights.size(); ++i)
    if (Weights[i] > Max)
      Max = Weights[i];

  if (Max <= UINT_MAX)
    return;

  // Smallest shift that brings the largest weight into uint32_t range.
  unsigned Shift = 0;
  while ((Max >> Shift) > UINT_MAX)
    ++Shift;

  // Apply the same shift everywhere so the relative proportions are kept.
  for (unsigned i = 0; i < Weights.size(); ++i)
    Weights[i] >>= Shift;
}
/// FoldValueComparisonIntoPredecessors - The specified terminator is a value
/// equality comparison instruction (either a switch or a branch on "X == c").
/// See if any of the predecessors of the terminator block are value comparisons
@ -770,6 +834,55 @@ bool SimplifyCFGOpt::FoldValueComparisonIntoPredecessors(TerminatorInst *TI,
// build.
SmallVector<BasicBlock*, 8> NewSuccessors;
// Update the branch weight metadata along the way
SmallVector<uint64_t, 8> Weights;
uint64_t PredDefaultWeight = 0;
bool PredHasWeights = HasBranchWeights(PTI);
bool SuccHasWeights = HasBranchWeights(TI);
if (PredHasWeights) {
MDNode* MD = PTI->getMetadata(LLVMContext::MD_prof);
assert(MD);
for (unsigned i = 1, e = MD->getNumOperands(); i < e; ++i) {
ConstantInt* CI = dyn_cast<ConstantInt>(MD->getOperand(i));
assert(CI);
Weights.push_back(CI->getValue().getZExtValue());
}
// If the predecessor is a conditional eq, then swap the default weight
// to be the first entry.
if (BranchInst* BI = dyn_cast<BranchInst>(PTI)) {
assert(Weights.size() == 2);
ICmpInst *ICI = cast<ICmpInst>(BI->getCondition());
if (ICI->getPredicate() == ICmpInst::ICMP_EQ) {
std::swap(Weights.front(), Weights.back());
}
}
PredDefaultWeight = Weights.front();
} else if (SuccHasWeights) {
// If there are no predecessor weights but there are successor weights,
// populate Weights with 1, which will later be scaled to the sum of
// successor's weights
Weights.assign(1 + PredCases.size(), 1);
PredDefaultWeight = 1;
}
uint64_t SuccDefaultWeight = 0;
if (SuccHasWeights) {
int Index = 0;
if (BranchInst* BI = dyn_cast<BranchInst>(TI)) {
ICmpInst* ICI = dyn_cast<ICmpInst>(BI->getCondition());
assert(ICI);
if (ICI->getPredicate() == ICmpInst::ICMP_EQ)
Index = 1;
}
SuccDefaultWeight = GetWeight(TI, Index)->getValue().getZExtValue();
}
if (PredDefault == BB) {
// If this is the default destination from PTI, only the edges in TI
// that don't occur in PTI, or that branch to BB will be activated.
@ -780,6 +893,12 @@ bool SimplifyCFGOpt::FoldValueComparisonIntoPredecessors(TerminatorInst *TI,
else {
// The default destination is BB, we don't need explicit targets.
std::swap(PredCases[i], PredCases.back());
if (PredHasWeights) {
std::swap(Weights[i+1], Weights.back());
Weights.pop_back();
}
PredCases.pop_back();
--i; --e;
}
@ -790,14 +909,35 @@ bool SimplifyCFGOpt::FoldValueComparisonIntoPredecessors(TerminatorInst *TI,
PredDefault = BBDefault;
NewSuccessors.push_back(BBDefault);
}
if (SuccHasWeights) {
ScaleWeights(TI, Weights);
Weights.front() *= SuccDefaultWeight;
} else if (PredHasWeights) {
Weights.front() /= (1 + BBCases.size());
}
for (unsigned i = 0, e = BBCases.size(); i != e; ++i)
if (!PTIHandled.count(BBCases[i].Value) &&
BBCases[i].Dest != BBDefault) {
PredCases.push_back(BBCases[i]);
NewSuccessors.push_back(BBCases[i].Dest);
if (SuccHasWeights) {
Weights.push_back(PredDefaultWeight *
GetWeight(TI, i)->getValue().getZExtValue());
} else if (PredHasWeights) {
// Split the old default's weight amongst the children
assert(PredDefaultWeight != 0);
Weights.push_back(PredDefaultWeight / (1 + BBCases.size()));
}
}
} else {
// FIXME: preserve branch weight metadata, similarly to the 'then'
// above. For now, drop it.
PredHasWeights = false;
SuccHasWeights = false;
// If this is not the default destination from PSI, only the edges
// in SI that occur in PSI with a destination of BB will be
// activated.
@ -851,6 +991,17 @@ bool SimplifyCFGOpt::FoldValueComparisonIntoPredecessors(TerminatorInst *TI,
for (unsigned i = 0, e = PredCases.size(); i != e; ++i)
NewSI->addCase(PredCases[i].Value, PredCases[i].Dest);
if (PredHasWeights || SuccHasWeights) {
// Halve the weights if any of them cannot fit in an uint32_t
FitWeights(Weights);
SmallVector<uint32_t, 8> MDWeights(Weights.begin(), Weights.end());
NewSI->setMetadata(LLVMContext::MD_prof,
MDBuilder(BB->getContext()).
createBranchWeights(MDWeights));
}
EraseTerminatorInstAndDCECond(PTI);
// Okay, last check. If BB is still a successor of PSI, then we must
@ -2349,6 +2500,9 @@ static bool SimplifyBranchOnICmpChain(BranchInst *BI, const TargetData *TD,
// transformation. A switch with one value is just a cond branch.
if (ExtraCase && Values.size() < 2) return false;
// TODO: Preserve branch weight metadata, similarly to how
// FoldValueComparisonIntoPredecessors preserves it.
// Figure out which block is which destination.
BasicBlock *DefaultBB = BI->getSuccessor(1);
BasicBlock *EdgeBB = BI->getSuccessor(0);