diff --git a/include/llvm/CodeGen/MachineBranchProbabilityInfo.h b/include/llvm/CodeGen/MachineBranchProbabilityInfo.h index 39d2f32982b..943901f59ff 100644 --- a/include/llvm/CodeGen/MachineBranchProbabilityInfo.h +++ b/include/llvm/CodeGen/MachineBranchProbabilityInfo.h @@ -34,8 +34,10 @@ class MachineBranchProbabilityInfo : public ImmutablePass { // weight to just "inherit" the non-zero weight of an adjacent successor. static const uint32_t DEFAULT_WEIGHT = 16; - // Get sum of the block successors' weights. - uint32_t getSumForBlock(MachineBasicBlock *MBB) const; + // Get sum of the block successors' weights, potentially scaling them to fit + // within 32-bits. If scaling is required, sets Scale based on the necessary + // adjustment. Any edge weights used with the sum should be divided by Scale. + uint32_t getSumForBlock(MachineBasicBlock *MBB, uint32_t &Scale) const; public: static char ID; diff --git a/lib/CodeGen/MachineBranchProbabilityInfo.cpp b/lib/CodeGen/MachineBranchProbabilityInfo.cpp index 987403796f4..0037d525151 100644 --- a/lib/CodeGen/MachineBranchProbabilityInfo.cpp +++ b/lib/CodeGen/MachineBranchProbabilityInfo.cpp @@ -27,19 +27,34 @@ INITIALIZE_PASS_END(MachineBranchProbabilityInfo, "machine-branch-prob", char MachineBranchProbabilityInfo::ID = 0; uint32_t MachineBranchProbabilityInfo:: -getSumForBlock(MachineBasicBlock *MBB) const { - uint32_t Sum = 0; - +getSumForBlock(MachineBasicBlock *MBB, uint32_t &Scale) const { + // First we compute the sum with 64-bits of precision, ensuring that cannot + // overflow by bounding the number of weights considered. Hopefully no one + // actually needs 2^32 successors. + assert(MBB->succ_size() < UINT32_MAX); + uint64_t Sum = 0; + Scale = 1; for (MachineBasicBlock::const_succ_iterator I = MBB->succ_begin(), E = MBB->succ_end(); I != E; ++I) { - MachineBasicBlock *Succ = *I; - uint32_t Weight = getEdgeWeight(MBB, Succ); - uint32_t PrevSum = Sum; - + uint32_t Weight = getEdgeWeight(MBB, *I); Sum += Weight; - assert(Sum > PrevSum); (void) PrevSum; } + // If the computed sum fits in 32-bits, we're done. + if (Sum <= UINT32_MAX) + return Sum; + + // Otherwise, compute the scale necessary to cause the weights to fit, and + // re-sum with that scale applied. + assert((Sum / UINT32_MAX) < UINT32_MAX); + Scale = (Sum / UINT32_MAX) + 1; + Sum = 0; + for (MachineBasicBlock::const_succ_iterator I = MBB->succ_begin(), + E = MBB->succ_end(); I != E; ++I) { + uint32_t Weight = getEdgeWeight(MBB, *I); + Sum += Weight / Scale; + } + assert(Sum <= UINT32_MAX); return Sum; } @@ -89,8 +104,9 @@ MachineBranchProbabilityInfo::getHotSucc(MachineBasicBlock *MBB) const { BranchProbability MachineBranchProbabilityInfo::getEdgeProbability(MachineBasicBlock *Src, MachineBasicBlock *Dst) const { - uint32_t N = getEdgeWeight(Src, Dst); - uint32_t D = getSumForBlock(Src); + uint32_t Scale = 1; + uint32_t D = getSumForBlock(Src, Scale); + uint32_t N = getEdgeWeight(Src, Dst) / Scale; return BranchProbability(N, D); } diff --git a/test/CodeGen/X86/block-placement.ll b/test/CodeGen/X86/block-placement.ll index fed27c63531..e41d52c4183 100644 --- a/test/CodeGen/X86/block-placement.ll +++ b/test/CodeGen/X86/block-placement.ll @@ -271,3 +271,54 @@ loop.body5: %ptr2 = load i32** undef, align 4 br label %loop.body3 } + +define i32 @problematic_switch() { +; This function's CFG caused overlow in the machine branch probability +; calculation, triggering asserts. Make sure we don't crash on it. +; CHECK: problematic_switch + +entry: + switch i32 undef, label %exit [ + i32 879, label %bogus + i32 877, label %step + i32 876, label %step + i32 875, label %step + i32 874, label %step + i32 873, label %step + i32 872, label %step + i32 868, label %step + i32 867, label %step + i32 866, label %step + i32 861, label %step + i32 860, label %step + i32 856, label %step + i32 855, label %step + i32 854, label %step + i32 831, label %step + i32 830, label %step + i32 829, label %step + i32 828, label %step + i32 815, label %step + i32 814, label %step + i32 811, label %step + i32 806, label %step + i32 805, label %step + i32 804, label %step + i32 803, label %step + i32 802, label %step + i32 801, label %step + i32 800, label %step + i32 799, label %step + i32 798, label %step + i32 797, label %step + i32 796, label %step + i32 795, label %step + ] +bogus: + unreachable +step: + br label %exit +exit: + %merge = phi i32 [ 3, %step ], [ 6, %entry ] + ret i32 %merge +}