From 01e223e92e62d284c46299b34d7b30a0b1bb8aae Mon Sep 17 00:00:00 2001 From: Hans Wennborg Date: Fri, 23 Jan 2015 20:43:51 +0000 Subject: [PATCH] LowerSwitch: replace unreachable default with popular case destination SimplifyCFG currently does this transformation, but I'm planning to remove that to allow other passes, such as this one, to exploit the unreachable default. This patch takes care to keep track of what case values are unreachable even after the transformation, allowing for more efficient lowering. Differential Revision: http://reviews.llvm.org/D6697 git-svn-id: https://llvm.org/svn/llvm-project/llvm/trunk@226934 91177308-0d34-0410-b5e6-96231b3b80d8 --- lib/Transforms/Utils/LowerSwitch.cpp | 194 ++++++++++++------ .../2014-06-11-SwitchDefaultUnreachableOpt.ll | 7 +- ...old-popular-case-to-unreachable-default.ll | 110 ++++++++++ 3 files changed, 248 insertions(+), 63 deletions(-) create mode 100644 test/Transforms/LowerSwitch/fold-popular-case-to-unreachable-default.ll diff --git a/lib/Transforms/Utils/LowerSwitch.cpp b/lib/Transforms/Utils/LowerSwitch.cpp index 35cd917330a..0a1e138e1b3 100644 --- a/lib/Transforms/Utils/LowerSwitch.cpp +++ b/lib/Transforms/Utils/LowerSwitch.cpp @@ -32,6 +32,23 @@ using namespace llvm; #define DEBUG_TYPE "lower-switch" namespace { + struct IntRange { + int64_t Low, High; + }; + // Return true iff R is covered by Ranges. + static bool IsInRanges(const IntRange &R, + const std::vector &Ranges) { + // Note: Ranges must be sorted, non-overlapping and non-adjacent. + + // Find the first range whose High field is >= R.High, + // then check if the Low field is <= R.Low. If so, we + // have a Range that covers R. + auto I = std::lower_bound( + Ranges.begin(), Ranges.end(), R, + [](const IntRange &A, const IntRange &B) { return A.High < B.High; }); + return I != Ranges.end() && I->Low <= R.Low; + } + /// LowerSwitch Pass - Replace all SwitchInst instructions with chained branch /// instructions. class LowerSwitch : public FunctionPass { @@ -68,7 +85,8 @@ namespace { BasicBlock *switchConvert(CaseItr Begin, CaseItr End, ConstantInt *LowerBound, ConstantInt *UpperBound, Value *Val, BasicBlock *Predecessor, - BasicBlock *OrigBlock, BasicBlock *Default); + BasicBlock *OrigBlock, BasicBlock *Default, + const std::vector &UnreachableRanges); BasicBlock *newLeafBlock(CaseRange &Leaf, Value *Val, BasicBlock *OrigBlock, BasicBlock *Default); unsigned Clusterify(CaseVector &Cases, SwitchInst *SI); @@ -172,12 +190,12 @@ static void fixPhis(BasicBlock *SuccBB, BasicBlock *OrigBB, BasicBlock *NewBB, // LowerBound and UpperBound are used to keep track of the bounds for Val // that have already been checked by a block emitted by one of the previous // calls to switchConvert in the call stack. -BasicBlock *LowerSwitch::switchConvert(CaseItr Begin, CaseItr End, - ConstantInt *LowerBound, - ConstantInt *UpperBound, Value *Val, - BasicBlock *Predecessor, - BasicBlock *OrigBlock, - BasicBlock *Default) { +BasicBlock * +LowerSwitch::switchConvert(CaseItr Begin, CaseItr End, ConstantInt *LowerBound, + ConstantInt *UpperBound, Value *Val, + BasicBlock *Predecessor, BasicBlock *OrigBlock, + BasicBlock *Default, + const std::vector &UnreachableRanges) { unsigned Size = End - Begin; if (Size == 1) { @@ -212,19 +230,19 @@ BasicBlock *LowerSwitch::switchConvert(CaseItr Begin, CaseItr End, // the smallest, so there is always a case range that has at least // a smaller value. ConstantInt *NewLowerBound = cast(Pivot.Low); - ConstantInt *NewUpperBound; - // If we don't have a Default block then it means that we can never - // have a value outside of a case range, so set the UpperBound to the highest - // value in the LHS part of the case ranges. - if (Default != nullptr) { - // Because NewLowerBound is never the smallest representable integer - // it is safe here to subtract one. - NewUpperBound = ConstantInt::get(NewLowerBound->getContext(), - NewLowerBound->getValue() - 1); - } else { - CaseItr LastLHS = LHS.begin() + LHS.size() - 1; - NewUpperBound = cast(LastLHS->High); + // Because NewLowerBound is never the smallest representable integer + // it is safe here to subtract one. + ConstantInt *NewUpperBound = ConstantInt::get(NewLowerBound->getContext(), + NewLowerBound->getValue() - 1); + + if (!UnreachableRanges.empty()) { + // Check if the gap between LHS's highest and NewLowerBound is unreachable. + int64_t GapLow = cast(LHS.back().High)->getSExtValue() + 1; + int64_t GapHigh = NewLowerBound->getSExtValue() - 1; + IntRange Gap = { GapLow, GapHigh }; + if (GapHigh >= GapLow && IsInRanges(Gap, UnreachableRanges)) + NewUpperBound = cast(LHS.back().High); } DEBUG(dbgs() << "LHS Bounds ==> "; @@ -252,10 +270,10 @@ BasicBlock *LowerSwitch::switchConvert(CaseItr Begin, CaseItr End, BasicBlock *LBranch = switchConvert(LHS.begin(), LHS.end(), LowerBound, NewUpperBound, Val, NewNode, OrigBlock, - Default); + Default, UnreachableRanges); BasicBlock *RBranch = switchConvert(RHS.begin(), RHS.end(), NewLowerBound, UpperBound, Val, NewNode, OrigBlock, - Default); + Default, UnreachableRanges); Function::iterator FI = OrigBlock; F->getBasicBlockList().insert(++FI, NewNode); @@ -380,26 +398,102 @@ void LowerSwitch::processSwitchInst(SwitchInst *SI) { Value *Val = SI->getCondition(); // The value we are switching on... BasicBlock* Default = SI->getDefaultDest(); - // If there is only the default destination, don't bother with the code below. + // If there is only the default destination, just branch. if (!SI->getNumCases()) { - BranchInst::Create(SI->getDefaultDest(), CurBlock); - CurBlock->getInstList().erase(SI); + BranchInst::Create(Default, CurBlock); + SI->eraseFromParent(); return; } - const bool DefaultIsUnreachable = - Default->size() == 1 && isa(Default->getTerminator()); + // Prepare cases vector. + CaseVector Cases; + unsigned numCmps = Clusterify(Cases, SI); + DEBUG(dbgs() << "Clusterify finished. Total clusters: " << Cases.size() + << ". Total compares: " << numCmps << "\n"); + DEBUG(dbgs() << "Cases: " << Cases << "\n"); + (void)numCmps; + + ConstantInt *LowerBound = nullptr; + ConstantInt *UpperBound = nullptr; + std::vector UnreachableRanges; + + if (isa(Default->getFirstNonPHIOrDbg())) { + // Make the bounds tightly fitted around the case value range, becase we + // know that the value passed to the switch must be exactly one of the case + // values. + assert(!Cases.empty()); + LowerBound = cast(Cases.front().Low); + UpperBound = cast(Cases.back().High); + + DenseMap Popularity; + unsigned MaxPop = 0; + BasicBlock *PopSucc = nullptr; + + IntRange R = { INT64_MIN, INT64_MAX }; + UnreachableRanges.push_back(R); + for (const auto &I : Cases) { + int64_t Low = cast(I.Low)->getSExtValue(); + int64_t High = cast(I.High)->getSExtValue(); + + IntRange &LastRange = UnreachableRanges.back(); + if (LastRange.Low == Low) { + // There is nothing left of the previous range. + UnreachableRanges.pop_back(); + } else { + // Terminate the previous range. + assert(Low > LastRange.Low); + LastRange.High = Low - 1; + } + if (High != INT64_MAX) { + IntRange R = { High + 1, INT64_MAX }; + UnreachableRanges.push_back(R); + } + + // Count popularity. + int64_t N = High - Low + 1; + unsigned &Pop = Popularity[I.BB]; + if ((Pop += N) > MaxPop) { + MaxPop = Pop; + PopSucc = I.BB; + } + } +#ifndef NDEBUG + /* UnreachableRanges should be sorted and the ranges non-adjacent. */ + for (auto I = UnreachableRanges.begin(), E = UnreachableRanges.end(); + I != E; ++I) { + assert(I->Low <= I->High); + auto Next = I + 1; + if (Next != E) { + assert(Next->Low > I->High); + } + } +#endif + + // Use the most popular block as the new default, reducing the number of + // cases. + assert(MaxPop > 0 && PopSucc); + Default = PopSucc; + for (CaseItr I = Cases.begin(); I != Cases.end();) { + if (I->BB == PopSucc) + I = Cases.erase(I); + else + ++I; + } + + // If there are no cases left, just branch. + if (Cases.empty()) { + BranchInst::Create(Default, CurBlock); + SI->eraseFromParent(); + return; + } + } + // Create a new, empty default block so that the new hierarchy of // if-then statements go to this and the PHI nodes are happy. - // if the default block is set as an unreachable we avoid creating one - // because will never be a valid target. - BasicBlock *NewDefault = nullptr; - if (!DefaultIsUnreachable) { - NewDefault = BasicBlock::Create(SI->getContext(), "NewDefault"); - F->getBasicBlockList().insert(Default, NewDefault); + BasicBlock *NewDefault = BasicBlock::Create(SI->getContext(), "NewDefault"); + F->getBasicBlockList().insert(Default, NewDefault); + BranchInst::Create(Default, NewDefault); - BranchInst::Create(Default, NewDefault); - } // If there is an entry in any PHI nodes for the default edge, make sure // to update them as well. for (BasicBlock::iterator I = Default->begin(); isa(I); ++I) { @@ -409,40 +503,18 @@ void LowerSwitch::processSwitchInst(SwitchInst *SI) { PN->setIncomingBlock((unsigned)BlockIdx, NewDefault); } - // Prepare cases vector. - CaseVector Cases; - unsigned numCmps = Clusterify(Cases, SI); - - DEBUG(dbgs() << "Clusterify finished. Total clusters: " << Cases.size() - << ". Total compares: " << numCmps << "\n"); - DEBUG(dbgs() << "Cases: " << Cases << "\n"); - (void)numCmps; - - ConstantInt *UpperBound = nullptr; - ConstantInt *LowerBound = nullptr; - - // Optimize the condition where Default is an unreachable block. In this case - // we can make the bounds tightly fitted around the case value ranges, - // because we know that the value passed to the switch should always be - // exactly one of the case values. - if (DefaultIsUnreachable) { - CaseItr LastCase = Cases.begin() + Cases.size() - 1; - UpperBound = cast(LastCase->High); - LowerBound = cast(Cases.begin()->Low); - } BasicBlock *SwitchBlock = switchConvert(Cases.begin(), Cases.end(), LowerBound, UpperBound, Val, - OrigBlock, OrigBlock, NewDefault); + OrigBlock, OrigBlock, NewDefault, UnreachableRanges); // Branch to our shiny new if-then stuff... BranchInst::Create(SwitchBlock, OrigBlock); // We are now done with the switch instruction, delete it. + BasicBlock *OldDefault = SI->getDefaultDest(); CurBlock->getInstList().erase(SI); - pred_iterator PI = pred_begin(Default), E = pred_end(Default); - // If the Default block has no more predecessors just remove it - if (PI == E) { - DeleteDeadBlock(Default); - } + // If the Default block has no more predecessors just remove it. + if (pred_begin(OldDefault) == pred_end(OldDefault)) + DeleteDeadBlock(OldDefault); } diff --git a/test/Transforms/LowerSwitch/2014-06-11-SwitchDefaultUnreachableOpt.ll b/test/Transforms/LowerSwitch/2014-06-11-SwitchDefaultUnreachableOpt.ll index 0f737211f59..ecdd767bdea 100644 --- a/test/Transforms/LowerSwitch/2014-06-11-SwitchDefaultUnreachableOpt.ll +++ b/test/Transforms/LowerSwitch/2014-06-11-SwitchDefaultUnreachableOpt.ll @@ -1,5 +1,8 @@ ; RUN: opt < %s -lowerswitch -S | FileCheck %s -; CHECK-NOT: {{.*}}icmp eq{{.*}} +; +; The switch is lowered with a single icmp. +; CHECK: icmp +; CHECK-NOT: icmp ; ;int foo(int a) { ; @@ -14,7 +17,7 @@ ; ;} -define i32 @foo(i32 %a) nounwind ssp uwtable { +define i32 @foo(i32 %a) { %1 = alloca i32, align 4 %2 = alloca i32, align 4 store i32 %a, i32* %2, align 4 diff --git a/test/Transforms/LowerSwitch/fold-popular-case-to-unreachable-default.ll b/test/Transforms/LowerSwitch/fold-popular-case-to-unreachable-default.ll new file mode 100644 index 00000000000..54929c5bfca --- /dev/null +++ b/test/Transforms/LowerSwitch/fold-popular-case-to-unreachable-default.ll @@ -0,0 +1,110 @@ +; RUN: opt %s -lowerswitch -S | FileCheck %s + +define void @foo(i32 %x, i32* %p) { +; Cases 2 and 4 are removed and become the new default case. +; It is now enough to use two icmps to lower the switch. +; +; CHECK-LABEL: @foo +; CHECK: icmp slt i32 %x, 5 +; CHECK: icmp eq i32 %x, 1 +; CHECK-NOT: icmp +; +entry: + switch i32 %x, label %default [ + i32 1, label %bb0 + i32 2, label %popular + i32 4, label %popular + i32 5, label %bb1 + ] +bb0: + store i32 0, i32* %p + br label %exit +bb1: + store i32 1, i32* %p + br label %exit +popular: + store i32 2, i32* %p + br label %exit +exit: + ret void +default: + unreachable +} + +define void @unreachable_gap(i64 %x, i32* %p) { +; Cases 6 and INT64_MAX become the new default, but we still exploit the fact +; that 3-4 is unreachable, so four icmps is enough. + +; CHECK-LABEL: @unreachable_gap +; CHECK: icmp slt i64 %x, 2 +; CHECK: icmp slt i64 %x, 5 +; CHECK: icmp eq i64 %x, 5 +; CHECK: icmp slt i64 %x, 1 +; CHECK-NOT: icmp + +entry: + switch i64 %x, label %default [ + i64 -9223372036854775808, label %bb0 + i64 1, label %bb1 + i64 2, label %bb2 + i64 5, label %bb3 + i64 6, label %bb4 + i64 9223372036854775807, label %bb4 + ] +bb0: + store i32 0, i32* %p + br label %exit +bb1: + store i32 1, i32* %p + br label %exit +bb2: + store i32 2, i32* %p + br label %exit +bb3: + store i32 3, i32* %p + br label %exit +bb4: + store i32 4, i32* %p + br label %exit +exit: + ret void +default: + unreachable +} + + + +define void @nocases(i32 %x, i32* %p) { +; Don't fall over when there are no cases. +; +; CHECK-LABEL: @nocases +; CHECK-LABEL: entry +; CHECK-NEXT: br label %default +; +entry: + switch i32 %x, label %default [ + ] +default: + unreachable +} + +define void @nocasesleft(i32 %x, i32* %p) { +; Cases 2 and 4 are removed and we are left with no cases. +; +; CHECK-LABEL: @nocasesleft +; CHECK-LABEL: entry +; CHECK-NEXT: br label %popular +; +entry: + switch i32 %x, label %default [ + i32 2, label %popular + i32 4, label %popular + ] +popular: + store i32 2, i32* %p + br label %exit +exit: + ret void +default: + unreachable +}