diff --git a/lib/Transforms/Scalar/PredicateSimplifier.cpp b/lib/Transforms/Scalar/PredicateSimplifier.cpp index 6578bccff12..16cc20c06a8 100644 --- a/lib/Transforms/Scalar/PredicateSimplifier.cpp +++ b/lib/Transforms/Scalar/PredicateSimplifier.cpp @@ -41,11 +41,11 @@ // %a < %b < %c < %d // // with four nodes representing the properties. The InequalityGraph provides -// queries (such as "isEqual") and mutators (such as "addEqual"). To implement -// "isLess(%a, %c)", we start with getNode(%c) and walk downwards until -// we reach %a or the leaf node. Note that the graph is directed and acyclic, -// but may contain joins, meaning that this walk is not a linear time -// algorithm. +// querying with "isRelatedBy" and mutators "addEquality" and "addInequality". +// To find a relationship, we start with one of the nodes and binary search +// through its list to find where the relationships with the second node start. +// Then we iterate through those to find the first relationship that dominates +// our context node. // // To create these properties, we wait until a branch or switch instruction // implies that a particular value is true (or false). 
The VRPSolver is @@ -74,13 +74,13 @@ #include "llvm/DerivedTypes.h" #include "llvm/Instructions.h" #include "llvm/Pass.h" +#include "llvm/ADT/DepthFirstIterator.h" #include "llvm/ADT/SetOperations.h" #include "llvm/ADT/SmallVector.h" #include "llvm/ADT/Statistic.h" #include "llvm/ADT/STLExtras.h" #include "llvm/Analysis/Dominators.h" #include "llvm/Analysis/ET-Forest.h" -#include "llvm/Assembly/Writer.h" #include "llvm/Support/CFG.h" #include "llvm/Support/Compiler.h" #include "llvm/Support/Debug.h" @@ -89,14 +89,82 @@ #include #include #include -#include using namespace llvm; STATISTIC(NumVarsReplaced, "Number of argument substitutions"); STATISTIC(NumInstruction , "Number of instructions removed"); STATISTIC(NumSimple , "Number of simple replacements"); +STATISTIC(NumBlocks , "Number of blocks marked unreachable"); namespace { + // SLT SGT ULT UGT EQ + // 0 1 0 1 0 -- GT 10 + // 0 1 0 1 1 -- GE 11 + // 0 1 1 0 0 -- SGTULT 12 + // 0 1 1 0 1 -- SGEULE 13 + // 0 1 1 1 0 -- SGTUNE 14 + // 0 1 1 1 1 -- SGEUANY 15 + // 1 0 0 1 0 -- SLTUGT 18 + // 1 0 0 1 1 -- SLEUGE 19 + // 1 0 1 0 0 -- LT 20 + // 1 0 1 0 1 -- LE 21 + // 1 0 1 1 0 -- SLTUNE 22 + // 1 0 1 1 1 -- SLEUANY 23 + // 1 1 0 1 0 -- SNEUGT 26 + // 1 1 0 1 1 -- SANYUGE 27 + // 1 1 1 0 0 -- SNEULT 28 + // 1 1 1 0 1 -- SANYULE 29 + // 1 1 1 1 0 -- NE 30 + enum LatticeBits { + EQ_BIT = 1, UGT_BIT = 2, ULT_BIT = 4, SGT_BIT = 8, SLT_BIT = 16 + }; + enum LatticeVal { + GT = SGT_BIT | UGT_BIT, + GE = GT | EQ_BIT, + LT = SLT_BIT | ULT_BIT, + LE = LT | EQ_BIT, + NE = SLT_BIT | SGT_BIT | ULT_BIT | UGT_BIT, + SGTULT = SGT_BIT | ULT_BIT, + SGEULE = SGTULT | EQ_BIT, + SLTUGT = SLT_BIT | UGT_BIT, + SLEUGE = SLTUGT | EQ_BIT, + SNEULT = SLT_BIT | SGT_BIT | ULT_BIT, + SNEUGT = SLT_BIT | SGT_BIT | UGT_BIT, + SLTUNE = SLT_BIT | ULT_BIT | UGT_BIT, + SGTUNE = SGT_BIT | ULT_BIT | UGT_BIT, + SLEUANY = SLT_BIT | ULT_BIT | UGT_BIT | EQ_BIT, + SGEUANY = SGT_BIT | ULT_BIT | UGT_BIT | EQ_BIT, + SANYULE = SLT_BIT | SGT_BIT | ULT_BIT | 
EQ_BIT, + SANYUGE = SLT_BIT | SGT_BIT | UGT_BIT | EQ_BIT + }; + + static bool validPredicate(LatticeVal LV) { + switch (LV) { + case GT: case GE: case LT: case LE: case NE: + case SGTULT: case SGTUNE: case SGEULE: + case SLTUGT: case SLTUNE: case SLEUGE: + case SNEULT: case SNEUGT: + case SLEUANY: case SGEUANY: case SANYULE: case SANYUGE: + return true; + default: + return false; + } + } + + /// reversePredicate - reverse the direction of the inequality + static LatticeVal reversePredicate(LatticeVal LV) { + unsigned reverse = LV ^ (SLT_BIT|SGT_BIT|ULT_BIT|UGT_BIT); //preserve EQ_BIT + if ((reverse & (SLT_BIT|SGT_BIT)) == 0) + reverse |= (SLT_BIT|SGT_BIT); + + if ((reverse & (ULT_BIT|UGT_BIT)) == 0) + reverse |= (ULT_BIT|UGT_BIT); + + LatticeVal Rev = static_cast(reverse); + assert(validPredicate(Rev) && "Failed reversing predicate."); + return Rev; + } + /// The InequalityGraph stores the relationships between values. /// Each Value in the graph is assigned to a Node. Nodes are pointer /// comparable for equality. The caller is expected to maintain the logical @@ -105,38 +173,51 @@ namespace { /// The InequalityGraph class may invalidate Node*s after any mutator call. /// @brief The InequalityGraph stores the relationships between values. class VISIBILITY_HIDDEN InequalityGraph { + ETNode *TreeRoot; + + InequalityGraph(); // DO NOT IMPLEMENT + InequalityGraph(InequalityGraph &); // DO NOT IMPLEMENT public: + explicit InequalityGraph(ETNode *TreeRoot) : TreeRoot(TreeRoot) {} + class Node; - // LT GT EQ - // 0 0 0 -- invalid (false) - // 0 0 1 -- invalid (EQ) - // 0 1 0 -- GT - // 0 1 1 -- GE - // 1 0 0 -- LT - // 1 0 1 -- LE - // 1 1 0 -- NE - // 1 1 1 -- invalid (true) - enum LatticeBits { - EQ_BIT = 1, GT_BIT = 2, LT_BIT = 4 - }; - enum LatticeVal { - GT = GT_BIT, GE = GT_BIT | EQ_BIT, - LT = LT_BIT, LE = LT_BIT | EQ_BIT, - NE = GT_BIT | LT_BIT + /// This is a StrictWeakOrdering predicate that sorts ETNodes by how many + /// children they have. 
With this, you can iterate through a list sorted by + /// this operation and the first matching entry is the most specific match + /// for your basic block. The order provided is total; ETNodes with the + /// same number of children are sorted by pointer address. + struct VISIBILITY_HIDDEN OrderByDominance { + bool operator()(const ETNode *LHS, const ETNode *RHS) const { + unsigned LHS_spread = LHS->getDFSNumOut() - LHS->getDFSNumIn(); + unsigned RHS_spread = RHS->getDFSNumOut() - RHS->getDFSNumIn(); + if (LHS_spread != RHS_spread) return LHS_spread < RHS_spread; + else return LHS < RHS; + } }; - static bool validPredicate(LatticeVal LV) { - return LV > 1 && LV < 7; - } + /// An Edge is contained inside a Node making one end of the edge implicit + /// and contains a pointer to the other end. The edge contains a lattice + /// value specifying the relationship between the two nodes. Further, there + /// is an ETNode specifying which subtree of the dominator the edge applies. + class VISIBILITY_HIDDEN Edge { + public: + Edge(unsigned T, LatticeVal V, ETNode *ST) + : To(T), LV(V), Subtree(ST) {} - private: - typedef std::map NodeMapType; - NodeMapType Nodes; + unsigned To; + LatticeVal LV; + ETNode *Subtree; - const InequalityGraph *ConcreteIG; + bool operator<(const Edge &edge) const { + if (To != edge.To) return To < edge.To; + else return OrderByDominance()(Subtree, edge.Subtree); + } + bool operator<(unsigned to) const { + return To < to; + } + }; - public: /// A single node in the InequalityGraph. This stores the canonical Value /// for the node, as well as the relationships with the neighbours. /// @@ -148,367 +229,488 @@ namespace { class VISIBILITY_HIDDEN Node { friend class InequalityGraph; + typedef SmallVector RelationsType; + RelationsType Relations; + Value *Canonical; - typedef SmallVector, 4> RelationsType; - RelationsType Relations; + // TODO: can this idea improve performance? 
+ //friend class std::vector; + //Node(Node &N) { RelationsType.swap(N.RelationsType); } + public: typedef RelationsType::iterator iterator; typedef RelationsType::const_iterator const_iterator; + Node(Value *V) : Canonical(V) {} + private: +#ifndef NDEBUG + public: + virtual void dump() const { + dump(*cerr.stream()); + } + private: + void dump(std::ostream &os) const { + os << *getValue() << ":\n"; + for (Node::const_iterator NI = begin(), NE = end(); NI != NE; ++NI) { + static const std::string names[32] = + { "000000", "000001", "000002", "000003", "000004", "000005", + "000006", "000007", "000008", "000009", " >", " >=", + " s>u<", "s>=u<=", " s>", " s>=", "000016", "000017", + " s", "s<=u>=", " <", " <=", " s<", " s<=", + "000024", "000025", " u>", " u>=", " u<", " u<=", + " !=", "000031" }; + os << " " << names[NI->LV] << " " << NI->To + << "(" << NI->Subtree << ")\n"; + } + } +#endif + + public: + iterator begin() { return Relations.begin(); } + iterator end() { return Relations.end(); } + const_iterator begin() const { return Relations.begin(); } + const_iterator end() const { return Relations.end(); } + + iterator find(unsigned n, ETNode *Subtree) { + iterator E = end(); + for (iterator I = std::lower_bound(begin(), E, n); + I != E && I->To == n; ++I) { + if (Subtree->DominatedBy(I->Subtree)) + return I; + } + return E; + } + + const_iterator find(unsigned n, ETNode *Subtree) const { + const_iterator E = end(); + for (const_iterator I = std::lower_bound(begin(), E, n); + I != E && I->To == n; ++I) { + if (Subtree->DominatedBy(I->Subtree)) + return I; + } + return E; + } + + Value *getValue() const + { + return Canonical; + } + /// Updates the lattice value for a given node. Create a new entry if /// one doesn't exist, otherwise it merges the values. The new lattice /// value must not be inconsistent with any previously existing value. 
- void update(Node *N, LatticeVal R) { - iterator I = find(N); + void update(unsigned n, LatticeVal R, ETNode *Subtree) { + assert(validPredicate(R) && "Invalid predicate."); + iterator I = find(n, Subtree); if (I == end()) { - Relations.push_back(std::make_pair(N, R)); + Edge edge(n, R, Subtree); + iterator Insert = std::lower_bound(begin(), end(), edge); + Relations.insert(Insert, edge); } else { - I->second = static_cast(I->second & R); - assert(validPredicate(I->second) && - "Invalid union of lattice values."); + LatticeVal LV = static_cast(I->LV & R); + assert(validPredicate(LV) && "Invalid union of lattice values."); + if (LV != I->LV) { + if (Subtree == I->Subtree) + I->LV = LV; + else { + assert(Subtree->DominatedBy(I->Subtree) && + "Find returned subtree that doesn't apply."); + + Edge edge(n, R, Subtree); + iterator Insert = std::lower_bound(begin(), end(), edge); + Relations.insert(Insert, edge); + } + } } } - - void assign(Node *N, LatticeVal R) { - iterator I = find(N); - if (I != end()) I->second = R; - - Relations.push_back(std::make_pair(N, R)); - } - - public: - iterator begin() { return Relations.begin(); } - iterator end() { return Relations.end(); } - iterator find(Node *N) { - iterator I = begin(); - for (iterator E = end(); I != E; ++I) - if (I->first == N) break; - return I; - } - - const_iterator begin() const { return Relations.begin(); } - const_iterator end() const { return Relations.end(); } - const_iterator find(Node *N) const { - const_iterator I = begin(); - for (const_iterator E = end(); I != E; ++I) - if (I->first == N) break; - return I; - } - - unsigned findIndex(Node *N) { - unsigned i = 0; - iterator I = begin(); - for (iterator E = end(); I != E; ++I, ++i) - if (I->first == N) return i; - return (unsigned)-1; - } - - void erase(iterator i) { Relations.erase(i); } - - Value *getValue() const { return Canonical; } - void setValue(Value *V) { Canonical = V; } - - void addNotEqual(Node *N) { update(N, NE); } - void addLess(Node *N) 
{ update(N, LT); } - void addLessEqual(Node *N) { update(N, LE); } - void addGreater(Node *N) { update(N, GT); } - void addGreaterEqual(Node *N) { update(N, GE); } }; - InequalityGraph() : ConcreteIG(NULL) {} - - InequalityGraph(const InequalityGraph &_IG) { -#if 0 // disable COW - if (_IG.ConcreteIG) ConcreteIG = _IG.ConcreteIG; - else ConcreteIG = &_IG; -#else - ConcreteIG = &_IG; - materialize(); -#endif - } - - ~InequalityGraph(); - private: - void materialize(); + struct VISIBILITY_HIDDEN NodeMapEdge { + Value *V; + unsigned index; + ETNode *Subtree; + + NodeMapEdge(Value *V, unsigned index, ETNode *Subtree) + : V(V), index(index), Subtree(Subtree) {} + + bool operator==(const NodeMapEdge &RHS) const { + return V == RHS.V && + Subtree == RHS.Subtree; + } + + bool operator<(const NodeMapEdge &RHS) const { + if (V != RHS.V) return V < RHS.V; + return OrderByDominance()(Subtree, RHS.Subtree); + } + + bool operator<(Value *RHS) const { + return V < RHS; + } + }; + + typedef std::vector NodeMapType; + NodeMapType NodeMap; + + std::vector Nodes; + + std::vector > Constants; + void initializeConstant(Constant *C, unsigned index) { + ConstantIntegral *CI = dyn_cast(C); + if (!CI) return; + + // XXX: instead of O(n) calls to addInequality, just find the 2, 3 or 4 + // nodes that are nearest less than or greater than (signed or unsigned). 
+ for (std::vector >::iterator + I = Constants.begin(), E = Constants.end(); I != E; ++I) { + ConstantIntegral *Other = I->first; + if (CI->getType() == Other->getType()) { + unsigned lv = 0; + + if (CI->getZExtValue() < Other->getZExtValue()) + lv |= ULT_BIT; + else + lv |= UGT_BIT; + + if (CI->getSExtValue() < Other->getSExtValue()) + lv |= SLT_BIT; + else + lv |= SGT_BIT; + + LatticeVal LV = static_cast(lv); + assert(validPredicate(LV) && "Not a valid predicate."); + if (!isRelatedBy(index, I->second, TreeRoot, LV)) + addInequality(index, I->second, TreeRoot, LV); + } + } + Constants.push_back(std::make_pair(CI, index)); + } public: - /// If the Value is in the graph, return the canonical form. Otherwise, - /// return the original Value. - Value *canonicalize(Value *V) const { - if (const Node *N = getNode(V)) - return N->getValue(); - else - return V; + /// node - returns the node object at a given index retrieved from getNode. + /// Index zero is reserved and may not be passed in here. The pointer + /// returned is valid until the next call to newNode or getOrInsertNode. + Node *node(unsigned index) { + assert(index != 0 && "Zero index is reserved for not found."); + assert(index <= Nodes.size() && "Index out of range."); + return &Nodes[index-1]; } - /// Returns the node currently representing Value V, or null if no such + /// Returns the node currently representing Value V, or zero if no such /// node exists. - Node *getNode(Value *V) { - materialize(); - - NodeMapType::const_iterator I = Nodes.find(V); - return (I != Nodes.end()) ? 
I->second : 0; + unsigned getNode(Value *V, ETNode *Subtree) { + NodeMapType::iterator E = NodeMap.end(); + NodeMapEdge Edge(V, 0, Subtree); + NodeMapType::iterator I = std::lower_bound(NodeMap.begin(), E, Edge); + while (I != E && I->V == V) { + if (Subtree->DominatedBy(I->Subtree)) + return I->index; + ++I; + } + return 0; } - const Node *getNode(Value *V) const { - if (ConcreteIG) return ConcreteIG->getNode(V); - - NodeMapType::const_iterator I = Nodes.find(V); - return (I != Nodes.end()) ? I->second : 0; - } - - Node *getOrInsertNode(Value *V) { - if (Node *N = getNode(V)) - return N; + /// getOrInsertNode - always returns a valid node index, creating a node + /// to match the Value if needed. + unsigned getOrInsertNode(Value *V, ETNode *Subtree) { + if (unsigned n = getNode(V, Subtree)) + return n; else return newNode(V); } - Node *newNode(Value *V) { - //DOUT << "new node: " << *V << "\n"; - materialize(); - Node *&N = Nodes[V]; - assert(N == 0 && "Node already exists for value."); - N = new Node(); - N->setValue(V); - return N; + /// newNode - creates a new node for a given Value and returns the index. + unsigned newNode(Value *V) { + Nodes.push_back(Node(V)); + + NodeMapEdge MapEntry = NodeMapEdge(V, Nodes.size(), TreeRoot); + assert(!std::binary_search(NodeMap.begin(), NodeMap.end(), MapEntry) && + "Attempt to create a duplicate Node."); + NodeMap.insert(std::lower_bound(NodeMap.begin(), NodeMap.end(), + MapEntry), MapEntry); + +#if 1 + // This is the missing piece to turn on VRP. + if (Constant *C = dyn_cast(V)) + initializeConstant(C, MapEntry.index); +#endif + + return MapEntry.index; } - /// Returns true iff the nodes are provably inequal. 
- bool isNotEqual(const Node *N1, const Node *N2) const { - if (N1 == N2) return false; - for (Node::const_iterator I = N1->begin(), E = N1->end(); I != E; ++I) { - if (I->first == N2) - return (I->second & EQ_BIT) == 0; - } - return isLess(N1, N2) || isGreater(N1, N2); + /// If the Value is in the graph, return the canonical form. Otherwise, + /// return the original Value. + Value *canonicalize(Value *V, ETNode *Subtree) { + if (isa(V)) return V; + + if (unsigned n = getNode(V, Subtree)) + return node(n)->getValue(); + else + return V; } - /// Returns true iff N1 is provably less than N2. - bool isLess(const Node *N1, const Node *N2) const { - if (N1 == N2) return false; - for (Node::const_iterator I = N2->begin(), E = N2->end(); I != E; ++I) { - if (I->first == N1) - return I->second == LT; - } - for (Node::const_iterator I = N2->begin(), E = N2->end(); I != E; ++I) { - if ((I->second & (LT_BIT | GT_BIT)) == LT_BIT) - if (isLess(N1, I->first)) return true; - } + /// isRelatedBy - true iff n1 op n2 + bool isRelatedBy(unsigned n1, unsigned n2, ETNode *Subtree, LatticeVal LV) { + if (n1 == n2) return LV & EQ_BIT; + + Node *N1 = node(n1); + Node::iterator I = N1->find(n2, Subtree), E = N1->end(); + if (I != E) return (I->LV & LV) == I->LV; + return false; } - /// Returns true iff N1 is provably less than or equal to N2. - bool isLessEqual(const Node *N1, const Node *N2) const { - if (N1 == N2) return true; - for (Node::const_iterator I = N2->begin(), E = N2->end(); I != E; ++I) { - if (I->first == N1) - return (I->second & (LT_BIT | GT_BIT)) == LT_BIT; - } - for (Node::const_iterator I = N2->begin(), E = N2->end(); I != E; ++I) { - if ((I->second & (LT_BIT | GT_BIT)) == LT_BIT) - if (isLessEqual(N1, I->first)) return true; - } - return false; - } - - /// Returns true iff N1 is provably greater than N2. - bool isGreater(const Node *N1, const Node *N2) const { - return isLess(N2, N1); - } - - /// Returns true iff N1 is provably greater than or equal to N2. 
- bool isGreaterEqual(const Node *N1, const Node *N2) const { - return isLessEqual(N2, N1); - } - // The add* methods assume that your input is logically valid and may // assertion-fail or infinitely loop if you attempt a contradiction. - void addEqual(Node *N, Value *V) { - materialize(); - Nodes[V] = N; - } + void addEquality(unsigned n, Value *V, ETNode *Subtree) { + assert(canonicalize(node(n)->getValue(), Subtree) == node(n)->getValue() + && "Node's 'canonical' choice isn't best within this subtree."); - void addNotEqual(Node *N1, Node *N2) { - assert(N1 != N2 && "A node can't be inequal to itself."); - materialize(); - N1->addNotEqual(N2); - N2->addNotEqual(N1); - } + // Suppose that we are given "%x -> node #1 (%y)". The problem is that + // we may already have "%z -> node #2 (%x)" somewhere above us in the + // graph. We need to find those edges and add "%z -> node #1 (%y)" + // to keep the lookups canonical. - /// N1 is less than N2. - void addLess(Node *N1, Node *N2) { - assert(N1 != N2 && !isLess(N2, N1) && "Attempt to create < cycle."); - materialize(); - N2->addLess(N1); - N1->addGreater(N2); - } + std::vector ToRepoint; + ToRepoint.push_back(V); - /// N1 is less than or equal to N2. - void addLessEqual(Node *N1, Node *N2) { - assert(N1 != N2 && "Nodes are equal. Use mergeNodes instead."); - assert(!isGreater(N1, N2) && "Impossible: Adding x <= y when x > y."); - materialize(); - N2->addLessEqual(N1); - N1->addGreaterEqual(N2); - } - - /// Find the transitive closure starting at a node walking down the edges - /// of type Val. Type Inserter must be an inserter that accepts Node *. - template - void transitiveClosure(Node *N, LatticeVal Val, Inserter insert) { - for (Node::iterator I = N->begin(), E = N->end(); I != E; ++I) { - if (I->second == Val) { - *insert = I->first; - transitiveClosure(I->first, Val, insert); + if (unsigned Conflict = getNode(V, Subtree)) { + // XXX: NodeMap.size() exceeds 68000 entries compiling kimwitu++! 
+ // This adds 57 seconds to the otherwise 3 second build. Unacceptable. + // + // IDEA: could we iterate 1..Nodes.size() calling getNode? It's + // O(n log n) but kimwitu++ only has about 300 nodes. + for (NodeMapType::iterator I = NodeMap.begin(), E = NodeMap.end(); + I != E; ++I) { + if (I->index == Conflict && Subtree->DominatedBy(I->Subtree)) + ToRepoint.push_back(I->V); } } + + for (std::vector::iterator VI = ToRepoint.begin(), + VE = ToRepoint.end(); VI != VE; ++VI) { + Value *V = *VI; + + // XXX: review this code. This may be doing too many insertions. + NodeMapEdge Edge(V, n, Subtree); + NodeMapType::iterator E = NodeMap.end(); + NodeMapType::iterator I = std::lower_bound(NodeMap.begin(), E, Edge); + if (I == E || I->V != V || I->Subtree != Subtree) { + // New Value + NodeMap.insert(I, Edge); + } else if (I != E && I->V == V && I->Subtree == Subtree) { + // Update best choice + I->index = n; + } + +#ifndef NDEBUG + Node *N = node(n); + if (isa(V)) { + if (isa(N->getValue())) { + assert(V == N->getValue() && "Constant equals different constant?"); + } + } +#endif + } } - /// Kills off all the nodes in Kill by replicating their properties into - /// node N. The elements of Kill must be unique. After merging, N's new - /// canonical value is NewCanonical. Type C must be a container of Node *. - template - void mergeNodes(Node *N, C &Kill, Value *NewCanonical); + /// addInequality - Sets n1 op n2. + /// It is also an error to call this on an inequality that is already true. + void addInequality(unsigned n1, unsigned n2, ETNode *Subtree, + LatticeVal LV1) { + assert(n1 != n2 && "A node can't be inequal to itself."); + + if (LV1 != NE) + assert(!isRelatedBy(n1, n2, Subtree, reversePredicate(LV1)) && + "Contradictory inequality."); + + Node *N1 = node(n1); + Node *N2 = node(n2); + + // Suppose we're adding %n1 < %n2. Find all the %a < %n1 and + // add %a < %n2 too. This keeps the graph fully connected. 
+ if (LV1 != NE) { + // Someone with a head for this sort of logic, please review this. + // Given that %x SLTUGT %y and %a SLEUANY %x, what is the relationship + // between %a and %y? I believe the below code is correct, but I don't + // think it's the most efficient solution. + + unsigned LV1_s = LV1 & (SLT_BIT|SGT_BIT); + unsigned LV1_u = LV1 & (ULT_BIT|UGT_BIT); + for (Node::iterator I = N1->begin(), E = N1->end(); I != E; ++I) { + if (I->LV != NE && I->To != n2) { + ETNode *Local_Subtree = NULL; + if (Subtree->DominatedBy(I->Subtree)) + Local_Subtree = Subtree; + else if (I->Subtree->DominatedBy(Subtree)) + Local_Subtree = I->Subtree; + + if (Local_Subtree) { + unsigned new_relationship = 0; + LatticeVal ILV = reversePredicate(I->LV); + unsigned ILV_s = ILV & (SLT_BIT|SGT_BIT); + unsigned ILV_u = ILV & (ULT_BIT|UGT_BIT); + + if (LV1_s != (SLT_BIT|SGT_BIT) && ILV_s == LV1_s) + new_relationship |= ILV_s; + + if (LV1_u != (ULT_BIT|UGT_BIT) && ILV_u == LV1_u) + new_relationship |= ILV_u; + + if (new_relationship) { + if ((new_relationship & (SLT_BIT|SGT_BIT)) == 0) + new_relationship |= (SLT_BIT|SGT_BIT); + if ((new_relationship & (ULT_BIT|UGT_BIT)) == 0) + new_relationship |= (ULT_BIT|UGT_BIT); + if ((LV1 & EQ_BIT) && (ILV & EQ_BIT)) + new_relationship |= EQ_BIT; + + LatticeVal NewLV = static_cast(new_relationship); + + node(I->To)->update(n2, NewLV, Local_Subtree); + N2->update(I->To, reversePredicate(NewLV), Local_Subtree); + } + } + } + } + + for (Node::iterator I = N2->begin(), E = N2->end(); I != E; ++I) { + if (I->LV != NE && I->To != n1) { + ETNode *Local_Subtree = NULL; + if (Subtree->DominatedBy(I->Subtree)) + Local_Subtree = Subtree; + else if (I->Subtree->DominatedBy(Subtree)) + Local_Subtree = I->Subtree; + + if (Local_Subtree) { + unsigned new_relationship = 0; + unsigned ILV_s = I->LV & (SLT_BIT|SGT_BIT); + unsigned ILV_u = I->LV & (ULT_BIT|UGT_BIT); + + if (LV1_s != (SLT_BIT|SGT_BIT) && ILV_s == LV1_s) + new_relationship |= ILV_s; + + if (LV1_u != 
(ULT_BIT|UGT_BIT) && ILV_u == LV1_u) + new_relationship |= ILV_u; + + if (new_relationship) { + if ((new_relationship & (SLT_BIT|SGT_BIT)) == 0) + new_relationship |= (SLT_BIT|SGT_BIT); + if ((new_relationship & (ULT_BIT|UGT_BIT)) == 0) + new_relationship |= (ULT_BIT|UGT_BIT); + if ((LV1 & EQ_BIT) && (I->LV & EQ_BIT)) + new_relationship |= EQ_BIT; + + LatticeVal NewLV = static_cast(new_relationship); + + N1->update(I->To, NewLV, Local_Subtree); + node(I->To)->update(n1, reversePredicate(NewLV), Local_Subtree); + } + } + } + } + } + + N1->update(n2, LV1, Subtree); + N2->update(n1, reversePredicate(LV1), Subtree); + } /// Removes a Value from the graph, but does not delete any nodes. As this /// method does not delete Nodes, V may not be the canonical choice for - /// any node. + /// a node with any relationships. It is invalid to call newNode on a Value + /// that has been removed. void remove(Value *V) { - materialize(); - - for (NodeMapType::iterator I = Nodes.begin(), E = Nodes.end(); I != E;) { - NodeMapType::iterator J = I++; - assert(J->second->getValue() != V && "Can't delete canonical choice."); - if (J->first == V) Nodes.erase(J); + for (unsigned i = 0; i < NodeMap.size();) { + NodeMapType::iterator I = NodeMap.begin()+i; + assert((node(I->index)->getValue() != V || node(I->index)->begin() == + node(I->index)->end()) && "Tried to delete in-use node."); + if (I->V == V) { +#ifndef NDEBUG + if (node(I->index)->getValue() == V) + node(I->index)->Canonical = NULL; +#endif + NodeMap.erase(I); + } else ++i; } } #ifndef NDEBUG - void debug(std::ostream &os) const { + virtual void dump() { + dump(*cerr.stream()); + } + + void dump(std::ostream &os) { std::set VisitedNodes; - for (NodeMapType::const_iterator I = Nodes.begin(), E = Nodes.end(); + for (NodeMapType::const_iterator I = NodeMap.begin(), E = NodeMap.end(); I != E; ++I) { - Node *N = I->second; - os << *I->first << " == " << *N->getValue() << "\n"; + Node *N = node(I->index); + os << *I->V << " == " << 
I->index << "(" << I->Subtree << ")\n"; if (VisitedNodes.insert(N).second) { - os << *N->getValue() << ":\n"; - for (Node::const_iterator NI = N->begin(), NE = N->end(); - NI != NE; ++NI) { - static const std::string names[8] = - { "00", "01", " <", "<=", " >", ">=", "!=", "07" }; - os << " " << names[NI->second] << " " - << *NI->first->getValue() << "\n"; - } + os << I->index << ". "; + if (!N->getValue()) os << "(deleted node)\n"; + else N->dump(os); } } } #endif }; - InequalityGraph::~InequalityGraph() { - if (ConcreteIG) return; + /// UnreachableBlocks keeps tracks of blocks that are for one reason or + /// another discovered to be unreachable. This is used to cull the graph when + /// analyzing instructions, and to mark blocks with the "unreachable" + /// terminator instruction after the function has executed. + class VISIBILITY_HIDDEN UnreachableBlocks { + private: + std::vector DeadBlocks; - std::vector Remove; - for (NodeMapType::iterator I = Nodes.begin(), E = Nodes.end(); - I != E; ++I) { - if (I->first == I->second->getValue()) - Remove.push_back(I->second); + public: + /// mark - mark a block as dead + void mark(BasicBlock *BB) { + std::vector::iterator E = DeadBlocks.end(); + std::vector::iterator I = + std::lower_bound(DeadBlocks.begin(), E, BB); + + if (I == E || *I != BB) DeadBlocks.insert(I, BB); } - for (std::vector::iterator I = Remove.begin(), E = Remove.end(); - I != E; ++I) { - delete *I; + + /// isDead - returns whether a block is known to be dead already + bool isDead(BasicBlock *BB) { + std::vector::iterator E = DeadBlocks.end(); + std::vector::iterator I = + std::lower_bound(DeadBlocks.begin(), E, BB); + + return I != E && *I == BB; } - } - template - void InequalityGraph::mergeNodes(Node *N, C &Kill, Value *NewCanonical) { - materialize(); + /// kill - replace the dead blocks' terminator with an UnreachableInst. 
+ bool kill() { + bool modified = false; + for (std::vector::iterator I = DeadBlocks.begin(), + E = DeadBlocks.end(); I != E; ++I) { + BasicBlock *BB = *I; - // Merge the relationships from the members of Kill into N. - for (typename C::iterator KI = Kill.begin(), KE = Kill.end(); - KI != KE; ++KI) { + DOUT << "unreachable block: " << BB->getName() << "\n"; - for (Node::iterator I = (*KI)->begin(), E = (*KI)->end(); I != E; ++I) { - if (I->first == N) continue; - - Node::iterator NI = N->find(I->first); - if (NI == N->end()) { - N->Relations.push_back(std::make_pair(I->first, I->second)); - } else { - unsigned char LV = NI->second & I->second; - if (LV == EQ_BIT) { - - assert(std::find(Kill.begin(), Kill.end(), I->first) != Kill.end() - && "Lost EQ property."); - N->erase(NI); - } else { - NI->second = static_cast(LV); - assert(InequalityGraph::validPredicate(NI->second) && - "Invalid union of lattice values."); - } + for (succ_iterator SI = succ_begin(BB), SE = succ_end(BB); + SI != SE; ++SI) { + BasicBlock *Succ = *SI; + Succ->removePredecessor(BB); } - // All edges are reciprocal; every Node that Kill points to also - // contains a pointer to Kill. Replace those with pointers with N. - unsigned iter = I->first->findIndex(*KI); - assert(iter != (unsigned)-1 && "Edge not reciprocal."); - I->first->assign(N, (I->first->begin()+iter)->second); - I->first->erase(I->first->begin()+iter); - } - - // Removing references from N to Kill. - Node::iterator NI = N->find(*KI); - if (NI != N->end()) { - N->erase(NI); // breaks reciprocity until Kill is deleted. + TerminatorInst *TI = BB->getTerminator(); + TI->replaceAllUsesWith(UndefValue::get(TI->getType())); + TI->eraseFromParent(); + new UnreachableInst(BB); + ++NumBlocks; + modified = true; } + DeadBlocks.clear(); + return modified; } - - N->setValue(NewCanonical); - - // Update value mapping to point to the merged node. 
- for (NodeMapType::iterator I = Nodes.begin(), E = Nodes.end(); - I != E; ++I) { - if (std::find(Kill.begin(), Kill.end(), I->second) != Kill.end()) - I->second = N; - } - - for (typename C::iterator KI = Kill.begin(), KE = Kill.end(); - KI != KE; ++KI) { - delete *KI; - } - } - - void InequalityGraph::materialize() { - if (!ConcreteIG) return; - const InequalityGraph *IG = ConcreteIG; - ConcreteIG = NULL; - - for (NodeMapType::const_iterator I = IG->Nodes.begin(), - E = IG->Nodes.end(); I != E; ++I) { - if (I->first == I->second->getValue()) { - Node *N = newNode(I->first); - N->Relations.reserve(N->Relations.size()); - } - } - for (NodeMapType::const_iterator I = IG->Nodes.begin(), - E = IG->Nodes.end(); I != E; ++I) { - if (I->first != I->second->getValue()) { - Nodes[I->first] = getNode(I->second->getValue()); - } else { - Node *Old = I->second; - Node *N = getNode(I->first); - for (Node::const_iterator NI = Old->begin(), NE = Old->end(); - NI != NE; ++NI) { - N->assign(getNode(NI->first->getValue()), NI->second); - } - } - } - } + }; /// VRPSolver keeps track of how changes to one variable affect other /// variables, and forwards changes along to the InequalityGraph. It @@ -516,15 +718,46 @@ namespace { /// @brief VRPSolver calculates inferences from a new relationship. class VISIBILITY_HIDDEN VRPSolver { private: - std::deque WorkList; + struct Operation { + Value *LHS, *RHS; + ICmpInst::Predicate Op; + + Instruction *Context; + }; + std::deque WorkList; InequalityGraph &IG; - const InequalityGraph &cIG; + UnreachableBlocks &UB; ETForest *Forest; ETNode *Top; + BasicBlock *TopBB; + Instruction *TopInst; + bool &modified; typedef InequalityGraph::Node Node; + /// IdomI - Determines whether one Instruction dominates another. 
+ bool IdomI(Instruction *I1, Instruction *I2) const { + BasicBlock *BB1 = I1->getParent(), + *BB2 = I2->getParent(); + if (BB1 == BB2) { + if (isa(I1)) return false; + if (isa(I2)) return true; + if (isa(I1) && !isa(I2)) return true; + if (!isa(I1) && isa(I2)) return false; + + for (BasicBlock::const_iterator I = BB1->begin(), E = BB1->end(); + I != E; ++I) { + if (&*I == I1) return true; + if (&*I == I2) return false; + } + assert(!"Instructions not found in parent BasicBlock?"); + } else { + return Forest->properlyDominates(BB1, BB2); + } + return false; + } + /// Returns true if V1 is a better canonical value than V2. bool compare(Value *V1, Value *V2) const { if (isa(V1)) @@ -539,664 +772,614 @@ namespace { Instruction *I1 = dyn_cast(V1); Instruction *I2 = dyn_cast(V2); - if (!I1 || !I2) return false; + if (!I1 || !I2) + return V1->getNumUses() < V2->getNumUses(); - BasicBlock *BB1 = I1->getParent(), - *BB2 = I2->getParent(); - if (BB1 == BB2) { - for (BasicBlock::const_iterator I = BB1->begin(), E = BB1->end(); - I != E; ++I) { - if (&*I == I1) return true; - if (&*I == I2) return false; - } - assert(!"Instructions not found in parent BasicBlock?"); - } else { - return Forest->properlyDominates(BB1, BB2); - } - return false; + return IdomI(I1, I2); } - void addToWorklist(Instruction *I) { - //DOUT << "addToWorklist: " << *I << "\n"; - - if (!isa(I) && !isa(I) && !isa(I)) - return; - - const Type *Ty = I->getType(); - if (Ty == Type::VoidTy || Ty->isFPOrFPVector()) return; - - if (isInstructionTriviallyDead(I)) return; - - WorkList.push_back(I); + // below - true if the Instruction is dominated by the current context + // block or instruction + bool below(Instruction *I) { + if (TopInst) + return IdomI(TopInst, I); + else { + ETNode *Node = Forest->getNodeForBlock(I->getParent()); + return Node == Top || Node->DominatedBy(Top); + } } - void addRecursive(Value *V) { - //DOUT << "addRecursive: " << *V << "\n"; + bool makeEqual(Value *V1, Value *V2) { + DOUT << 
"makeEqual(" << *V1 << ", " << *V2 << ")\n"; - Instruction *I = dyn_cast(V); - if (I) - addToWorklist(I); - else if (!isa(V)) - return; + if (V1 == V2) return true; - //DOUT << "addRecursive uses...\n"; - for (Value::use_iterator UI = V->use_begin(), UE = V->use_end(); - UI != UE; ++UI) { - // Use must be either be dominated by Top, or dominate Top. - if (Instruction *Inst = dyn_cast(*UI)) { - ETNode *INode = Forest->getNodeForBlock(Inst->getParent()); - if (INode->DominatedBy(Top) || Top->DominatedBy(INode)) - addToWorklist(Inst); + if (isa(V1) && isa(V2)) + return false; + + unsigned n1 = IG.getNode(V1, Top), n2 = IG.getNode(V2, Top); + + if (n1 && n2) { + if (n1 == n2) return true; + if (IG.isRelatedBy(n1, n2, Top, NE)) return false; + } + + if (n1) assert(V1 == IG.node(n1)->getValue() && "Value isn't canonical."); + if (n2) assert(V2 == IG.node(n2)->getValue() && "Value isn't canonical."); + + if (compare(V2, V1)) { std::swap(V1, V2); std::swap(n1, n2); } + + assert(!isa(V2) && "Tried to remove a constant."); + + SetVector Remove; + if (n2) Remove.insert(n2); + + if (n1 && n2) { + // Suppose we're being told that %x == %y, and %x <= %z and %y >= %z. + // We can't just merge %x and %y because the relationship with %z would + // be EQ and that's invalid. What we're doing is looking for any nodes + // %z such that %x <= %z and %y >= %z, and vice versa. + // + // Also handle %a <= %b and %c <= %a when adding %b <= %c. 
+ + Node *N1 = IG.node(n1); + Node::iterator end = N1->end(); + for (unsigned i = 0; i < Remove.size(); ++i) { + Node *N = IG.node(Remove[i]); + Value *V = N->getValue(); + for (Node::iterator I = N->begin(), E = N->end(); I != E; ++I) { + if (I->LV & EQ_BIT) { + if (Top == I->Subtree || Top->DominatedBy(I->Subtree)) { + Node::iterator NI = N1->find(I->To, Top); + if (NI != end) { + if (!(NI->LV & EQ_BIT)) return false; + if (isRelatedBy(V, IG.node(NI->To)->getValue(), + ICmpInst::ICMP_NE)) + return false; + Remove.insert(NI->To); + } + } + } + } + } + + // See if one of the nodes about to be removed is actually a better + // canonical choice than n1. + unsigned orig_n1 = n1; + std::vector::iterator DontRemove = Remove.end(); + for (std::vector::iterator I = Remove.begin()+1 /* skip n2 */, + E = Remove.end(); I != E; ++I) { + unsigned n = *I; + Value *V = IG.node(n)->getValue(); + if (compare(V, V1)) { + V1 = V; + n1 = n; + DontRemove = I; + } + } + if (DontRemove != Remove.end()) { + unsigned n = *DontRemove; + Remove.remove(n); + Remove.insert(orig_n1); } } - if (I) { - //DOUT << "addRecursive ops...\n"; - for (User::op_iterator OI = I->op_begin(), OE = I->op_end(); - OI != OE; ++OI) { - if (Instruction *Inst = dyn_cast(*OI)) - addToWorklist(Inst); + // We'd like to allow makeEqual on two values to perform a simple + // substitution without every creating nodes in the IG whenever possible. + // + // The first iteration through this loop operates on V2 before going + // through the Remove list and operating on those too. If all of the + // iterations performed simple replacements then we exit early. + bool exitEarly = true; + unsigned i = 0; + for (Value *R = V2; i == 0 || i < Remove.size(); ++i) { + if (i) R = IG.node(Remove[i])->getValue(); // skip n2. + + // Try to replace the whole instruction. If we can, we're done. 
+ Instruction *I2 = dyn_cast(R); + if (I2 && below(I2)) { + std::vector ToNotify; + for (Value::use_iterator UI = R->use_begin(), UE = R->use_end(); + UI != UE;) { + Use &TheUse = UI.getUse(); + ++UI; + if (Instruction *I = dyn_cast(TheUse.getUser())) + ToNotify.push_back(I); + } + + DOUT << "Simply removing " << *I2 + << ", replacing with " << *V1 << "\n"; + I2->replaceAllUsesWith(V1); + // leave it dead; it'll get erased later. + ++NumInstruction; + modified = true; + + for (std::vector::iterator II = ToNotify.begin(), + IE = ToNotify.end(); II != IE; ++II) { + opsToDef(*II); + } + + continue; + } + + // Otherwise, replace all dominated uses. + for (Value::use_iterator UI = R->use_begin(), UE = R->use_end(); + UI != UE;) { + Use &TheUse = UI.getUse(); + ++UI; + if (Instruction *I = dyn_cast(TheUse.getUser())) { + if (below(I)) { + TheUse.set(V1); + modified = true; + ++NumVarsReplaced; + opsToDef(I); + } + } + } + + // If that killed the instruction, stop here. + if (I2 && isInstructionTriviallyDead(I2)) { + DOUT << "Killed all uses of " << *I2 + << ", replacing with " << *V1 << "\n"; + continue; + } + + // If we make it to here, then we will need to create a node for N1. + // Otherwise, we can skip out early! + exitEarly = false; + } + + if (exitEarly) return true; + + // Create N1. + // XXX: this should call newNode, but instead the node might be created + // in isRelatedBy. That's also a fixme. + if (!n1) n1 = IG.getOrInsertNode(V1, Top); + + // Migrate relationships from removed nodes to N1. 
+ Node *N1 = IG.node(n1); + for (std::vector::iterator I = Remove.begin(), E = Remove.end(); + I != E; ++I) { + unsigned n = *I; + Node *N = IG.node(n); + for (Node::iterator NI = N->begin(), NE = N->end(); NI != NE; ++NI) { + if (Top == NI->Subtree || NI->Subtree->DominatedBy(Top)) { + if (NI->To == n1) { + assert((NI->LV & EQ_BIT) && "Node inequal to itself."); + continue; + } + if (Remove.count(NI->To)) + continue; + + IG.node(NI->To)->update(n1, reversePredicate(NI->LV), Top); + N1->update(NI->To, NI->LV, Top); + } } } - //DOUT << "exit addRecursive (" << *V << ").\n"; + + // Point V2 (and all items in Remove) to N1. + if (!n2) + IG.addEquality(n1, V2, Top); + else { + for (std::vector::iterator I = Remove.begin(), + E = Remove.end(); I != E; ++I) { + IG.addEquality(n1, IG.node(*I)->getValue(), Top); + } + } + + // If !Remove.empty() then V2 = Remove[0]->getValue(). + // Even when Remove is empty, we still want to process V2. + i = 0; + for (Value *R = V2; i == 0 || i < Remove.size(); ++i) { + if (i) R = IG.node(Remove[i])->getValue(); // skip n2. + + if (Instruction *I2 = dyn_cast(R)) defToOps(I2); + for (Value::use_iterator UI = V2->use_begin(), UE = V2->use_end(); + UI != UE;) { + Use &TheUse = UI.getUse(); + ++UI; + if (Instruction *I = dyn_cast(TheUse.getUser())) { + opsToDef(I); + } + } + } + + return true; + } + + /// cmpInstToLattice - converts an CmpInst::Predicate to lattice value + /// Requires that the lattice value be valid; does not accept ICMP_EQ. 
+ static LatticeVal cmpInstToLattice(ICmpInst::Predicate Pred) { + switch (Pred) { + case ICmpInst::ICMP_EQ: + assert(!"No matching lattice value."); + return static_cast(EQ_BIT); + default: + assert(!"Invalid 'icmp' predicate."); + case ICmpInst::ICMP_NE: + return NE; + case ICmpInst::ICMP_UGT: + return SNEUGT; + case ICmpInst::ICMP_UGE: + return SANYUGE; + case ICmpInst::ICMP_ULT: + return SNEULT; + case ICmpInst::ICMP_ULE: + return SANYULE; + case ICmpInst::ICMP_SGT: + return SGTUNE; + case ICmpInst::ICMP_SGE: + return SGEUANY; + case ICmpInst::ICMP_SLT: + return SLTUNE; + case ICmpInst::ICMP_SLE: + return SLEUANY; + } } public: - VRPSolver(InequalityGraph &IG, ETForest *Forest, BasicBlock *TopBB) - : IG(IG), cIG(IG), Forest(Forest), Top(Forest->getNodeForBlock(TopBB)) {} + VRPSolver(InequalityGraph &IG, UnreachableBlocks &UB, ETForest *Forest, + bool &modified, BasicBlock *TopBB) + : IG(IG), + UB(UB), + Forest(Forest), + Top(Forest->getNodeForBlock(TopBB)), + TopBB(TopBB), + TopInst(NULL), + modified(modified) {} + + VRPSolver(InequalityGraph &IG, UnreachableBlocks &UB, ETForest *Forest, + bool &modified, Instruction *TopInst) + : IG(IG), + UB(UB), + Forest(Forest), + TopInst(TopInst), + modified(modified) + { + TopBB = TopInst->getParent(); + Top = Forest->getNodeForBlock(TopBB); + } + + bool isRelatedBy(Value *V1, Value *V2, ICmpInst::Predicate Pred) const { + if (Constant *C1 = dyn_cast(V1)) + if (Constant *C2 = dyn_cast(V2)) + return ConstantExpr::getCompare(Pred, C1, C2) == + ConstantBool::getTrue(); + + // XXX: this is lousy. If we're passed a Constant, then we might miss + // some relationships if it isn't in the IG because the relationships + // added by initializeConstant are missing. 
+ if (isa(V1)) IG.getOrInsertNode(V1, Top); + if (isa(V2)) IG.getOrInsertNode(V2, Top); + + if (unsigned n1 = IG.getNode(V1, Top)) + if (unsigned n2 = IG.getNode(V2, Top)) { + if (n1 == n2) return Pred == ICmpInst::ICMP_EQ || + Pred == ICmpInst::ICMP_ULE || + Pred == ICmpInst::ICMP_UGE || + Pred == ICmpInst::ICMP_SLE || + Pred == ICmpInst::ICMP_SGE; + if (Pred == ICmpInst::ICMP_EQ) return false; + return IG.isRelatedBy(n1, n2, Top, cmpInstToLattice(Pred)); + } - bool isEqual(Value *V1, Value *V2) const { - if (V1 == V2) return true; - if (const Node *N1 = cIG.getNode(V1)) - return N1 == cIG.getNode(V2); return false; } - bool isNotEqual(Value *V1, Value *V2) const { - if (V1 == V2) return false; - if (const Node *N1 = cIG.getNode(V1)) - if (const Node *N2 = cIG.getNode(V2)) - return cIG.isNotEqual(N1, N2); - return false; + /// add - adds a new property to the work queue + void add(Value *V1, Value *V2, ICmpInst::Predicate Pred, + Instruction *I = NULL) { + DOUT << "adding " << *V1 << " " << Pred << " " << *V2; + if (I) DOUT << " context: " << *I; + else DOUT << " default context"; + DOUT << "\n"; + + WorkList.push_back(Operation()); + Operation &O = WorkList.back(); + O.LHS = V1, O.RHS = V2, O.Op = Pred, O.Context = I; } - bool isLess(Value *V1, Value *V2) const { - if (V1 == V2) return false; - if (const Node *N1 = cIG.getNode(V1)) - if (const Node *N2 = cIG.getNode(V2)) - return cIG.isLess(N1, N2); - return false; - } + /// defToOps - Given an instruction definition that we've learned something + /// new about, find any new relationships between its operands. + void defToOps(Instruction *I) { + Instruction *NewContext = below(I) ? 
I : TopInst; + Value *Canonical = IG.canonicalize(I, Top); - bool isLessEqual(Value *V1, Value *V2) const { - if (V1 == V2) return true; - if (const Node *N1 = cIG.getNode(V1)) - if (const Node *N2 = cIG.getNode(V2)) - return cIG.isLessEqual(N1, N2); - return false; - } + if (BinaryOperator *BO = dyn_cast(I)) { + const Type *Ty = BO->getType(); + assert(!Ty->isFPOrFPVector() && "Float in work queue!"); - bool isGreater(Value *V1, Value *V2) const { - if (V1 == V2) return false; - if (const Node *N1 = cIG.getNode(V1)) - if (const Node *N2 = cIG.getNode(V2)) - return cIG.isGreater(N1, N2); - return false; - } + Value *Op0 = IG.canonicalize(BO->getOperand(0), Top); + Value *Op1 = IG.canonicalize(BO->getOperand(1), Top); - bool isGreaterEqual(Value *V1, Value *V2) const { - if (V1 == V2) return true; - if (const Node *N1 = IG.getNode(V1)) - if (const Node *N2 = IG.getNode(V2)) - return cIG.isGreaterEqual(N1, N2); - return false; - } + // TODO: "and bool true, %x" EQ %y then %x EQ %y. - // All of the add* functions return true if the InequalityGraph represents - // the property, and false if there is a logical contradiction. On false, - // you may no longer perform any queries on the InequalityGraph. - - bool addEqual(Value *V1, Value *V2) { - //DOUT << "addEqual(" << *V1 << ", " << *V2 << ")\n"; - if (isEqual(V1, V2)) return true; - - const Node *cN1 = cIG.getNode(V1), *cN2 = cIG.getNode(V2); - - if (cN1 && cN2 && cIG.isNotEqual(cN1, cN2)) - return false; - - if (compare(V2, V1)) { std::swap(V1, V2); std::swap(cN1, cN2); } - - if (cN1) { - if (ConstantBool *CB = dyn_cast(V1)) { - Node *N1 = IG.getNode(V1); - - // When "addEqual" is performed and the new value is a ConstantBool, - // iterate through the NE set and fix them up to be EQ of the - // opposite bool. 
- - for (Node::iterator I = N1->begin(), E = N1->end(); I != E; ++I) - if ((I->second & 1) == 0) { - assert(N1 != I->first && "Node related to itself?"); - addEqual(I->first->getValue(), - ConstantBool::get(!CB->getValue())); + switch (BO->getOpcode()) { + case Instruction::And: { + // "and int %a, %b" EQ -1 then %a EQ -1 and %b EQ -1 + // "and bool %a, %b" EQ true then %a EQ true and %b EQ true + ConstantIntegral *CI = ConstantIntegral::getAllOnesValue(Ty); + if (Canonical == CI) { + add(CI, Op0, ICmpInst::ICMP_EQ, NewContext); + add(CI, Op1, ICmpInst::ICMP_EQ, NewContext); } + } break; + case Instruction::Or: { + // "or int %a, %b" EQ 0 then %a EQ 0 and %b EQ 0 + // "or bool %a, %b" EQ false then %a EQ false and %b EQ false + Constant *Zero = Constant::getNullValue(Ty); + if (Canonical == Zero) { + add(Zero, Op0, ICmpInst::ICMP_EQ, NewContext); + add(Zero, Op1, ICmpInst::ICMP_EQ, NewContext); + } + } break; + case Instruction::Xor: { + // "xor bool true, %a" EQ true then %a EQ false + // "xor bool true, %a" EQ false then %a EQ true + // "xor bool false, %a" EQ true then %a EQ true + // "xor bool false, %a" EQ false then %a EQ false + // "xor int %c, %a" EQ %c then %a EQ 0 + // "xor int %c, %a" NE %c then %a NE 0 + // 1. Repeat all of the above, with order of operands reversed. + Value *LHS = Op0; + Value *RHS = Op1; + if (!isa(LHS)) std::swap(LHS, RHS); + + if (ConstantBool *CB = dyn_cast(Canonical)) { + if (ConstantBool *A = dyn_cast(LHS)) + add(RHS, ConstantBool::get(A->getValue() ^ CB->getValue()), + ICmpInst::ICMP_EQ, NewContext); + } + if (Canonical == LHS) { + if (isa(Canonical)) + add(RHS, Constant::getNullValue(Ty), ICmpInst::ICMP_EQ, + NewContext); + } else if (isRelatedBy(LHS, Canonical, ICmpInst::ICMP_NE)) { + add(RHS, Constant::getNullValue(Ty), ICmpInst::ICMP_NE, + NewContext); + } + } break; + default: + break; + } + } else if (ICmpInst *IC = dyn_cast(I)) { + // "icmp ult int %a, int %y" EQ true then %a u< y + // etc. 
+ + if (Canonical == ConstantBool::getTrue()) { + add(IC->getOperand(0), IC->getOperand(1), IC->getPredicate(), + NewContext); + } else if (Canonical == ConstantBool::getFalse()) { + add(IC->getOperand(0), IC->getOperand(1), + ICmpInst::getInversePredicate(IC->getPredicate()), NewContext); + } + } else if (SelectInst *SI = dyn_cast(I)) { + if (I->getType()->isFPOrFPVector()) return; + + // Given: "%a = select bool %x, int %b, int %c" + // %a EQ %b and %b NE %c then %x EQ true + // %a EQ %c and %b NE %c then %x EQ false + + Value *True = SI->getTrueValue(); + Value *False = SI->getFalseValue(); + if (isRelatedBy(True, False, ICmpInst::ICMP_NE)) { + if (Canonical == IG.canonicalize(True, Top) || + isRelatedBy(Canonical, False, ICmpInst::ICMP_NE)) + add(SI->getCondition(), ConstantBool::getTrue(), + ICmpInst::ICMP_EQ, NewContext); + else if (Canonical == IG.canonicalize(False, Top) || + isRelatedBy(I, True, ICmpInst::ICMP_NE)) + add(SI->getCondition(), ConstantBool::getFalse(), + ICmpInst::ICMP_EQ, NewContext); } } + // TODO: CastInst "%a = cast ... %b" where %a is EQ or NE a constant. + } - if (!cN2) { - if (Instruction *I2 = dyn_cast(V2)) { - ETNode *Node_I2 = Forest->getNodeForBlock(I2->getParent()); - if (Top != Node_I2 && Node_I2->DominatedBy(Top)) { - Value *V = V1; - if (cN1 && compare(V1, cN1->getValue())) V = cN1->getValue(); - //DOUT << "Simply removing " << *I2 - // << ", replacing with " << *V << "\n"; - I2->replaceAllUsesWith(V); - // leave it dead; it'll get erased later. - ++NumSimple; - addRecursive(V1); - return true; + /// opsToDef - A new relationship was discovered involving one of this + /// instruction's operands. Find any new relationship involving the + /// definition. + void opsToDef(Instruction *I) { + Instruction *NewContext = below(I) ? 
I : TopInst; + + if (BinaryOperator *BO = dyn_cast(I)) { + Value *Op0 = IG.canonicalize(BO->getOperand(0), Top); + Value *Op1 = IG.canonicalize(BO->getOperand(1), Top); + + if (ConstantIntegral *CI0 = dyn_cast(Op0)) + if (ConstantIntegral *CI1 = dyn_cast(Op1)) { + add(BO, ConstantExpr::get(BO->getOpcode(), CI0, CI1), + ICmpInst::ICMP_EQ, NewContext); + return; + } + + // "%y = and bool true, %x" then %x EQ %y. + // "%y = or bool false, %x" then %x EQ %y. + if (BO->getOpcode() == Instruction::Or) { + Constant *Zero = Constant::getNullValue(BO->getType()); + if (Op0 == Zero) { + add(BO, Op1, ICmpInst::ICMP_EQ, NewContext); + return; + } else if (Op1 == Zero) { + add(BO, Op0, ICmpInst::ICMP_EQ, NewContext); + return; + } + } else if (BO->getOpcode() == Instruction::And) { + Constant *AllOnes = ConstantIntegral::getAllOnesValue(BO->getType()); + if (Op0 == AllOnes) { + add(BO, Op1, ICmpInst::ICMP_EQ, NewContext); + return; + } else if (Op1 == AllOnes) { + add(BO, Op0, ICmpInst::ICMP_EQ, NewContext); + return; } } - } - Node *N1 = IG.getNode(V1), *N2 = IG.getNode(V2); + // "%x = add int %y, %z" and %x EQ %y then %z EQ 0 + // "%x = mul int %y, %z" and %x EQ %y then %z EQ 1 + // 1. Repeat all of the above, with order of operands reversed. + // "%x = udiv int %y, %z" and %x EQ %y then %z EQ 1 - if ( N1 && !N2) { - IG.addEqual(N1, V2); - if (compare(V1, N1->getValue())) N1->setValue(V1); - } - if (!N1 && N2) { - IG.addEqual(N2, V1); - if (compare(V1, N2->getValue())) N2->setValue(V1); - } - if ( N1 && N2) { - // Suppose we're being told that %x == %y, and %x <= %z and %y >= %z. - // We can't just merge %x and %y because the relationship with %z would - // be EQ and that's invalid; they need to be the same Node. - // - // What we're doing is looking for any chain of nodes reaching %z such - // that %x <= %z and %y >= %z, and vice versa. The cool part is that - // every node in between is also equal because of the squeeze principle. 
- - std::vector N1_GE, N2_LE, N1_LE, N2_GE; - IG.transitiveClosure(N1, InequalityGraph::GE, back_inserter(N1_GE)); - std::sort(N1_GE.begin(), N1_GE.end()); - N1_GE.erase(std::unique(N1_GE.begin(), N1_GE.end()), N1_GE.end()); - IG.transitiveClosure(N2, InequalityGraph::LE, back_inserter(N2_LE)); - std::sort(N1_LE.begin(), N1_LE.end()); - N1_LE.erase(std::unique(N1_LE.begin(), N1_LE.end()), N1_LE.end()); - IG.transitiveClosure(N1, InequalityGraph::LE, back_inserter(N1_LE)); - std::sort(N2_GE.begin(), N2_GE.end()); - N2_GE.erase(std::unique(N2_GE.begin(), N2_GE.end()), N2_GE.end()); - std::unique(N2_GE.begin(), N2_GE.end()); - IG.transitiveClosure(N2, InequalityGraph::GE, back_inserter(N2_GE)); - std::sort(N2_LE.begin(), N2_LE.end()); - N2_LE.erase(std::unique(N2_LE.begin(), N2_LE.end()), N2_LE.end()); - - std::vector Set1, Set2; - std::set_intersection(N1_GE.begin(), N1_GE.end(), - N2_LE.begin(), N2_LE.end(), - back_inserter(Set1)); - std::set_intersection(N1_LE.begin(), N1_LE.end(), - N2_GE.begin(), N2_GE.end(), - back_inserter(Set2)); - - std::vector Equal; - std::set_union(Set1.begin(), Set1.end(), Set2.begin(), Set2.end(), - back_inserter(Equal)); - - Value *Best = N1->getValue(); - if (compare(N2->getValue(), Best)) Best = N2->getValue(); - - for (std::vector::iterator I = Equal.begin(), E = Equal.end(); - I != E; ++I) { - Value *V = (*I)->getValue(); - if (compare(V, Best)) Best = V; - } - - Equal.push_back(N2); - IG.mergeNodes(N1, Equal, Best); - } - if (!N1 && !N2) IG.addEqual(IG.newNode(V1), V2); - - addRecursive(V1); - addRecursive(V2); - - return true; - } - - bool addNotEqual(Value *V1, Value *V2) { - //DOUT << "addNotEqual(" << *V1 << ", " << *V2 << ")\n"); - if (isNotEqual(V1, V2)) return true; - - // Never permit %x NE true/false. 
- if (ConstantBool *B1 = dyn_cast(V1)) { - return addEqual(ConstantBool::get(!B1->getValue()), V2); - } else if (ConstantBool *B2 = dyn_cast(V2)) { - return addEqual(V1, ConstantBool::get(!B2->getValue())); - } - - Node *N1 = IG.getOrInsertNode(V1), - *N2 = IG.getOrInsertNode(V2); - - if (N1 == N2) return false; - - IG.addNotEqual(N1, N2); - - addRecursive(V1); - addRecursive(V2); - - return true; - } - - /// Set V1 less than V2. - bool addLess(Value *V1, Value *V2) { - if (isLess(V1, V2)) return true; - if (isGreaterEqual(V1, V2)) return false; - - Node *N1 = IG.getOrInsertNode(V1), *N2 = IG.getOrInsertNode(V2); - - if (N1 == N2) return false; - - IG.addLess(N1, N2); - - addRecursive(V1); - addRecursive(V2); - - return true; - } - - /// Set V1 less than or equal to V2. - bool addLessEqual(Value *V1, Value *V2) { - if (isLessEqual(V1, V2)) return true; - if (V1 == V2) return true; - - if (isLessEqual(V2, V1)) - return addEqual(V1, V2); - - if (isGreater(V1, V2)) return false; - - Node *N1 = IG.getOrInsertNode(V1), - *N2 = IG.getOrInsertNode(V2); - - if (N1 == N2) return true; - - IG.addLessEqual(N1, N2); - - addRecursive(V1); - addRecursive(V2); - - return true; - } - - void solve() { - DOUT << "WorkList entry, size: " << WorkList.size() << "\n"; - while (!WorkList.empty()) { - DOUT << "WorkList size: " << WorkList.size() << "\n"; - - Instruction *I = WorkList.front(); - WorkList.pop_front(); - - Value *Canonical = cIG.canonicalize(I); - const Type *Ty = I->getType(); - - //DOUT << "solving: " << *I << "\n"; - //DEBUG(IG.debug(*cerr.stream())); - - if (BinaryOperator *BO = dyn_cast(I)) { - Value *Op0 = cIG.canonicalize(BO->getOperand(0)), - *Op1 = cIG.canonicalize(BO->getOperand(1)); - - ConstantIntegral *CI1 = dyn_cast(Op0), - *CI2 = dyn_cast(Op1); - - if (CI1 && CI2) - addEqual(BO, ConstantExpr::get(BO->getOpcode(), CI1, CI2)); + Value *Known = Op0, *Unknown = Op1; + if (Known != BO) std::swap(Known, Unknown); + if (Known == BO) { + const Type *Ty = 
BO->getType(); + assert(!Ty->isFPOrFPVector() && "Float in work queue!"); switch (BO->getOpcode()) { - case Instruction::And: { - // "and int %a, %b" EQ -1 then %a EQ -1 and %b EQ -1 - // "and bool %a, %b" EQ true then %a EQ true and %b EQ true - ConstantIntegral *CI = ConstantIntegral::getAllOnesValue(Ty); - if (Canonical == CI) { - addEqual(CI, Op0); - addEqual(CI, Op1); - } - } break; - case Instruction::Or: { - // "or int %a, %b" EQ 0 then %a EQ 0 and %b EQ 0 - // "or bool %a, %b" EQ false then %a EQ false and %b EQ false - Constant *Zero = Constant::getNullValue(Ty); - if (Canonical == Zero) { - addEqual(Zero, Op0); - addEqual(Zero, Op1); - } - } break; - case Instruction::Xor: { - // "xor bool true, %a" EQ true then %a EQ false - // "xor bool true, %a" EQ false then %a EQ true - // "xor bool false, %a" EQ true then %a EQ true - // "xor bool false, %a" EQ false then %a EQ false - // "xor int %c, %a" EQ %c then %a EQ 0 - // "xor int %c, %a" NE %c then %a NE 0 - // 1. Repeat all of the above, with order of operands reversed. 
- Value *LHS = Op0, *RHS = Op1; - if (!isa(LHS)) std::swap(LHS, RHS); + default: break; + case Instruction::Xor: + case Instruction::Or: + case Instruction::Add: + case Instruction::Sub: + add(Unknown, Constant::getNullValue(Ty), ICmpInst::ICMP_EQ, NewContext); + break; + case Instruction::UDiv: + case Instruction::SDiv: + if (Unknown == Op0) break; // otherwise, fallthrough + case Instruction::And: + case Instruction::Mul: + Constant *One = NULL; + if (isa(Unknown)) + One = ConstantInt::get(Ty, 1); + else if (isa(Unknown)) + One = ConstantBool::getTrue(); - if (ConstantBool *CB = dyn_cast(Canonical)) { - if (ConstantBool *A = dyn_cast(LHS)) - addEqual(RHS, ConstantBool::get(A->getValue() ^ - CB->getValue())); - } - if (Canonical == LHS) { - if (isa(Canonical)) - addEqual(RHS, Constant::getNullValue(Ty)); - } else if (isNotEqual(LHS, Canonical)) { - addNotEqual(RHS, Constant::getNullValue(Ty)); - } - } break; - default: + if (One) add(Unknown, One, ICmpInst::ICMP_EQ, NewContext); break; } + } - // "%x = add int %y, %z" and %x EQ %y then %z EQ 0 - // "%x = mul int %y, %z" and %x EQ %y then %z EQ 1 - // 1. Repeat all of the above, with order of operands reversed. 
- // "%x = fdiv float %y, %z" and %x EQ %y then %z EQ 1 - Value *Known = Op0, *Unknown = Op1; - if (Known != BO) std::swap(Known, Unknown); - if (Known == BO) { - switch (BO->getOpcode()) { - default: break; - case Instruction::Xor: - case Instruction::Or: - case Instruction::Add: - case Instruction::Sub: - if (!Ty->isFloatingPoint()) - addEqual(Unknown, Constant::getNullValue(Ty)); - break; - case Instruction::UDiv: - case Instruction::SDiv: - case Instruction::FDiv: - if (Unknown == Op0) break; // otherwise, fallthrough - case Instruction::And: - case Instruction::Mul: - Constant *One = NULL; - if (isa(Unknown)) - One = ConstantInt::get(Ty, 1); - else if (isa(Unknown)) - One = ConstantFP::get(Ty, 1); - else if (isa(Unknown)) - One = ConstantBool::getTrue(); + // TODO: "%a = add int %b, 1" and %b > %z then %a >= %z. - if (One) addEqual(Unknown, One); - break; + } else if (ICmpInst *IC = dyn_cast(I)) { + // "%a = icmp ult %b, %c" and %b u< %c then %a EQ true + // "%a = icmp ult %b, %c" and %b u>= %c then %a EQ false + // etc. + + Value *Op0 = IG.canonicalize(IC->getOperand(0), Top); + Value *Op1 = IG.canonicalize(IC->getOperand(1), Top); + + ICmpInst::Predicate Pred = IC->getPredicate(); + if (isRelatedBy(Op0, Op1, Pred)) { + add(IC, ConstantBool::getTrue(), ICmpInst::ICMP_EQ, NewContext); + } else if (isRelatedBy(Op0, Op1, ICmpInst::getInversePredicate(Pred))) { + add(IC, ConstantBool::getFalse(), ICmpInst::ICMP_EQ, NewContext); + } + + // TODO: make the predicate more strict, if possible. 
+ + } else if (SelectInst *SI = dyn_cast(I)) { + // Given: "%a = select bool %x, int %b, int %c" + // %x EQ true then %a EQ %b + // %x EQ false then %a EQ %c + // %b EQ %c then %a EQ %b + + Value *Canonical = IG.canonicalize(SI->getCondition(), Top); + if (Canonical == ConstantBool::getTrue()) { + add(SI, SI->getTrueValue(), ICmpInst::ICMP_EQ, NewContext); + } else if (Canonical == ConstantBool::getFalse()) { + add(SI, SI->getFalseValue(), ICmpInst::ICMP_EQ, NewContext); + } else if (IG.canonicalize(SI->getTrueValue(), Top) == + IG.canonicalize(SI->getFalseValue(), Top)) { + add(SI, SI->getTrueValue(), ICmpInst::ICMP_EQ, NewContext); + } + } + // TODO: CastInst "%a = cast ... %b" where %b is EQ or NE a constant. + } + + /// solve - process the work queue + /// Return false if a logical contradiction occurs. + void solve() { + //DOUT << "WorkList entry, size: " << WorkList.size() << "\n"; + while (!WorkList.empty()) { + //DOUT << "WorkList size: " << WorkList.size() << "\n"; + + Operation &O = WorkList.front(); + if (O.Context) { + TopInst = O.Context; + Top = Forest->getNodeForBlock(TopInst->getParent()); + } + O.LHS = IG.canonicalize(O.LHS, Top); + O.RHS = IG.canonicalize(O.RHS, Top); + + assert(O.LHS == IG.canonicalize(O.LHS, Top) && "Canonicalize isn't."); + assert(O.RHS == IG.canonicalize(O.RHS, Top) && "Canonicalize isn't."); + + DOUT << "solving " << *O.LHS << " " << O.Op << " " << *O.RHS; + if (O.Context) DOUT << " context: " << *O.Context; + else DOUT << " default context"; + DOUT << "\n"; + + DEBUG(IG.dump()); + + // TODO: actually check the constants and add to UB. 
+ if (isa(O.LHS) && isa(O.RHS)) { + WorkList.pop_front(); + continue; + } + + if (O.Op == ICmpInst::ICMP_EQ) { + if (!makeEqual(O.LHS, O.RHS)) + UB.mark(TopBB); + } else { + LatticeVal LV = cmpInstToLattice(O.Op); + + if ((LV & EQ_BIT) && + isRelatedBy(O.LHS, O.RHS, ICmpInst::getSwappedPredicate(O.Op))) { + if (!makeEqual(O.LHS, O.RHS)) + UB.mark(TopBB); + } else { + if (isRelatedBy(O.LHS, O.RHS, ICmpInst::getInversePredicate(O.Op))){ + DOUT << "inequality contradiction!\n"; + WorkList.pop_front(); + continue; + } + + unsigned n1 = IG.getOrInsertNode(O.LHS, Top); + unsigned n2 = IG.getOrInsertNode(O.RHS, Top); + + if (n1 == n2) { + if (O.Op != ICmpInst::ICMP_UGE && O.Op != ICmpInst::ICMP_ULE && + O.Op != ICmpInst::ICMP_SGE && O.Op != ICmpInst::ICMP_SLE) + UB.mark(TopBB); + + WorkList.pop_front(); + continue; + } + + if (IG.isRelatedBy(n1, n2, Top, LV)) { + WorkList.pop_front(); + continue; + } + + IG.addInequality(n1, n2, Top, LV); + + if (Instruction *I1 = dyn_cast(O.LHS)) defToOps(I1); + if (isa(O.LHS) || isa(O.LHS)) { + for (Value::use_iterator UI = O.LHS->use_begin(), + UE = O.LHS->use_end(); UI != UE;) { + Use &TheUse = UI.getUse(); + ++UI; + if (Instruction *I = dyn_cast(TheUse.getUser())) { + opsToDef(I); + } + } + } + if (Instruction *I2 = dyn_cast(O.RHS)) defToOps(I2); + if (isa(O.RHS) || isa(O.RHS)) { + for (Value::use_iterator UI = O.RHS->use_begin(), + UE = O.RHS->use_end(); UI != UE;) { + Use &TheUse = UI.getUse(); + ++UI; + if (Instruction *I = dyn_cast(TheUse.getUser())) { + opsToDef(I); + } + } } } - } else if (FCmpInst *CI = dyn_cast(I)) { - Value *Op0 = cIG.canonicalize(CI->getOperand(0)), - *Op1 = cIG.canonicalize(CI->getOperand(1)); - - ConstantFP *CI1 = dyn_cast(Op0), - *CI2 = dyn_cast(Op1); - - if (CI1 && CI2) - addEqual(CI, ConstantExpr::getFCmp(CI->getPredicate(), CI1, CI2)); - - switch (CI->getPredicate()) { - case FCmpInst::FCMP_OEQ: - case FCmpInst::FCMP_UEQ: - // "eq int %a, %b" EQ true then %a EQ %b - // "eq int %a, %b" EQ false then %a 
NE %b - if (Canonical == ConstantBool::getTrue()) - addEqual(Op0, Op1); - else if (Canonical == ConstantBool::getFalse()) - addNotEqual(Op0, Op1); - - // %a EQ %b then "eq int %a, %b" EQ true - // %a NE %b then "eq int %a, %b" EQ false - if (isEqual(Op0, Op1)) - addEqual(CI, ConstantBool::getTrue()); - else if (isNotEqual(Op0, Op1)) - addEqual(CI, ConstantBool::getFalse()); - - break; - case FCmpInst::FCMP_ONE: - case FCmpInst::FCMP_UNE: - // "ne int %a, %b" EQ true then %a NE %b - // "ne int %a, %b" EQ false then %a EQ %b - if (Canonical == ConstantBool::getTrue()) - addNotEqual(Op0, Op1); - else if (Canonical == ConstantBool::getFalse()) - addEqual(Op0, Op1); - - // %a EQ %b then "ne int %a, %b" EQ false - // %a NE %b then "ne int %a, %b" EQ true - if (isEqual(Op0, Op1)) - addEqual(CI, ConstantBool::getFalse()); - else if (isNotEqual(Op0, Op1)) - addEqual(CI, ConstantBool::getTrue()); - - break; - case FCmpInst::FCMP_ULT: - case FCmpInst::FCMP_OLT: - // "lt int %a, %b" EQ true then %a LT %b - // "lt int %a, %b" EQ false then %b LE %a - if (Canonical == ConstantBool::getTrue()) - addLess(Op0, Op1); - else if (Canonical == ConstantBool::getFalse()) - addLessEqual(Op1, Op0); - - // %a LT %b then "lt int %a, %b" EQ true - // %a GE %b then "lt int %a, %b" EQ false - if (isLess(Op0, Op1)) - addEqual(CI, ConstantBool::getTrue()); - else if (isGreaterEqual(Op0, Op1)) - addEqual(CI, ConstantBool::getFalse()); - - break; - case FCmpInst::FCMP_ULE: - case FCmpInst::FCMP_OLE: - // "le int %a, %b" EQ true then %a LE %b - // "le int %a, %b" EQ false then %b LT %a - if (Canonical == ConstantBool::getTrue()) - addLessEqual(Op0, Op1); - else if (Canonical == ConstantBool::getFalse()) - addLess(Op1, Op0); - - // %a LE %b then "le int %a, %b" EQ true - // %a GT %b then "le int %a, %b" EQ false - if (isLessEqual(Op0, Op1)) - addEqual(CI, ConstantBool::getTrue()); - else if (isGreater(Op0, Op1)) - addEqual(CI, ConstantBool::getFalse()); - - break; - case FCmpInst::FCMP_UGT: - case 
FCmpInst::FCMP_OGT: - // "gt int %a, %b" EQ true then %b LT %a - // "gt int %a, %b" EQ false then %a LE %b - if (Canonical == ConstantBool::getTrue()) - addLess(Op1, Op0); - else if (Canonical == ConstantBool::getFalse()) - addLessEqual(Op0, Op1); - - // %a GT %b then "gt int %a, %b" EQ true - // %a LE %b then "gt int %a, %b" EQ false - if (isGreater(Op0, Op1)) - addEqual(CI, ConstantBool::getTrue()); - else if (isLessEqual(Op0, Op1)) - addEqual(CI, ConstantBool::getFalse()); - - break; - case FCmpInst::FCMP_UGE: - case FCmpInst::FCMP_OGE: - // "ge int %a, %b" EQ true then %b LE %a - // "ge int %a, %b" EQ false then %a LT %b - if (Canonical == ConstantBool::getTrue()) - addLessEqual(Op1, Op0); - else if (Canonical == ConstantBool::getFalse()) - addLess(Op0, Op1); - - // %a GE %b then "ge int %a, %b" EQ true - // %a LT %b then "lt int %a, %b" EQ false - if (isGreaterEqual(Op0, Op1)) - addEqual(CI, ConstantBool::getTrue()); - else if (isLess(Op0, Op1)) - addEqual(CI, ConstantBool::getFalse()); - - break; - default: - break; - } - - // "%x = add int %y, %z" and %x EQ %y then %z EQ 0 - // "%x = mul int %y, %z" and %x EQ %y then %z EQ 1 - // 1. Repeat all of the above, with order of operands reversed. 
- // "%x = fdiv float %y, %z" and %x EQ %y then %z EQ 1 - Value *Known = Op0, *Unknown = Op1; - if (Known != BO) std::swap(Known, Unknown); - } else if (ICmpInst *CI = dyn_cast(I)) { - Value *Op0 = cIG.canonicalize(CI->getOperand(0)), - *Op1 = cIG.canonicalize(CI->getOperand(1)); - - ConstantIntegral *CI1 = dyn_cast(Op0), - *CI2 = dyn_cast(Op1); - - if (CI1 && CI2) - addEqual(CI, ConstantExpr::getICmp(CI->getPredicate(), CI1, CI2)); - - switch (CI->getPredicate()) { - case ICmpInst::ICMP_EQ: - // "eq int %a, %b" EQ true then %a EQ %b - // "eq int %a, %b" EQ false then %a NE %b - if (Canonical == ConstantBool::getTrue()) - addEqual(Op0, Op1); - else if (Canonical == ConstantBool::getFalse()) - addNotEqual(Op0, Op1); - - // %a EQ %b then "eq int %a, %b" EQ true - // %a NE %b then "eq int %a, %b" EQ false - if (isEqual(Op0, Op1)) - addEqual(CI, ConstantBool::getTrue()); - else if (isNotEqual(Op0, Op1)) - addEqual(CI, ConstantBool::getFalse()); - - break; - case ICmpInst::ICMP_NE: - // "ne int %a, %b" EQ true then %a NE %b - // "ne int %a, %b" EQ false then %a EQ %b - if (Canonical == ConstantBool::getTrue()) - addNotEqual(Op0, Op1); - else if (Canonical == ConstantBool::getFalse()) - addEqual(Op0, Op1); - - // %a EQ %b then "ne int %a, %b" EQ false - // %a NE %b then "ne int %a, %b" EQ true - if (isEqual(Op0, Op1)) - addEqual(CI, ConstantBool::getFalse()); - else if (isNotEqual(Op0, Op1)) - addEqual(CI, ConstantBool::getTrue()); - - break; - case ICmpInst::ICMP_ULT: - case ICmpInst::ICMP_SLT: - // "lt int %a, %b" EQ true then %a LT %b - // "lt int %a, %b" EQ false then %b LE %a - if (Canonical == ConstantBool::getTrue()) - addLess(Op0, Op1); - else if (Canonical == ConstantBool::getFalse()) - addLessEqual(Op1, Op0); - - // %a LT %b then "lt int %a, %b" EQ true - // %a GE %b then "lt int %a, %b" EQ false - if (isLess(Op0, Op1)) - addEqual(CI, ConstantBool::getTrue()); - else if (isGreaterEqual(Op0, Op1)) - addEqual(CI, ConstantBool::getFalse()); - - break; - case 
ICmpInst::ICMP_ULE: - case ICmpInst::ICMP_SLE: - // "le int %a, %b" EQ true then %a LE %b - // "le int %a, %b" EQ false then %b LT %a - if (Canonical == ConstantBool::getTrue()) - addLessEqual(Op0, Op1); - else if (Canonical == ConstantBool::getFalse()) - addLess(Op1, Op0); - - // %a LE %b then "le int %a, %b" EQ true - // %a GT %b then "le int %a, %b" EQ false - if (isLessEqual(Op0, Op1)) - addEqual(CI, ConstantBool::getTrue()); - else if (isGreater(Op0, Op1)) - addEqual(CI, ConstantBool::getFalse()); - - break; - case ICmpInst::ICMP_UGT: - case ICmpInst::ICMP_SGT: - // "gt int %a, %b" EQ true then %b LT %a - // "gt int %a, %b" EQ false then %a LE %b - if (Canonical == ConstantBool::getTrue()) - addLess(Op1, Op0); - else if (Canonical == ConstantBool::getFalse()) - addLessEqual(Op0, Op1); - - // %a GT %b then "gt int %a, %b" EQ true - // %a LE %b then "gt int %a, %b" EQ false - if (isGreater(Op0, Op1)) - addEqual(CI, ConstantBool::getTrue()); - else if (isLessEqual(Op0, Op1)) - addEqual(CI, ConstantBool::getFalse()); - - break; - case ICmpInst::ICMP_UGE: - case ICmpInst::ICMP_SGE: - // "ge int %a, %b" EQ true then %b LE %a - // "ge int %a, %b" EQ false then %a LT %b - if (Canonical == ConstantBool::getTrue()) - addLessEqual(Op1, Op0); - else if (Canonical == ConstantBool::getFalse()) - addLess(Op0, Op1); - - // %a GE %b then "ge int %a, %b" EQ true - // %a LT %b then "lt int %a, %b" EQ false - if (isGreaterEqual(Op0, Op1)) - addEqual(CI, ConstantBool::getTrue()); - else if (isLess(Op0, Op1)) - addEqual(CI, ConstantBool::getFalse()); - - break; - default: - break; - } - - // "%x = add int %y, %z" and %x EQ %y then %z EQ 0 - // "%x = mul int %y, %z" and %x EQ %y then %z EQ 1 - // 1. Repeat all of the above, with order of operands reversed. 
- // "%x = fdiv float %y, %z" and %x EQ %y then %z EQ 1 - Value *Known = Op0, *Unknown = Op1; - if (Known != BO) std::swap(Known, Unknown); - } else if (SelectInst *SI = dyn_cast(I)) { - // Given: "%a = select bool %x, int %b, int %c" - // %a EQ %b then %x EQ true - // %a EQ %c then %x EQ false - if (isEqual(I, SI->getTrueValue()) || - isNotEqual(I, SI->getFalseValue())) - addEqual(SI->getCondition(), ConstantBool::getTrue()); - else if (isEqual(I, SI->getFalseValue()) || - isNotEqual(I, SI->getTrueValue())) - addEqual(SI->getCondition(), ConstantBool::getFalse()); - - // %x EQ true then %a EQ %b - // %x EQ false then %a NE %b - if (isEqual(SI->getCondition(), ConstantBool::getTrue())) - addEqual(SI, SI->getTrueValue()); - else if (isEqual(SI->getCondition(), ConstantBool::getFalse())) - addEqual(SI, SI->getFalseValue()); } + WorkList.pop_front(); } } }; @@ -1209,16 +1392,10 @@ namespace { DominatorTree *DT; ETForest *Forest; bool modified; + InequalityGraph *IG; + UnreachableBlocks UB; - class State { - public: - BasicBlock *ToVisit; - InequalityGraph *IG; - - State(BasicBlock *BB, InequalityGraph *IG) : ToVisit(BB), IG(IG) {} - }; - - std::vector WorkList; + std::vector WorkList; public: bool runOnFunction(Function &F); @@ -1227,8 +1404,6 @@ namespace { AU.addRequiredID(BreakCriticalEdgesID); AU.addRequired(); AU.addRequired(); - AU.setPreservesCFG(); - AU.addPreservedID(BreakCriticalEdgesID); } private: @@ -1241,12 +1416,14 @@ namespace { class VISIBILITY_HIDDEN Forwards : public InstVisitor { friend class InstVisitor; PredicateSimplifier *PS; + DominatorTree::Node *DTNode; public: InequalityGraph &IG; + UnreachableBlocks &UB; - Forwards(PredicateSimplifier *PS, InequalityGraph &IG) - : PS(PS), IG(IG) {} + Forwards(PredicateSimplifier *PS, DominatorTree::Node *DTNode) + : PS(PS), DTNode(DTNode), IG(*PS->IG), UB(PS->UB) {} void visitTerminatorInst(TerminatorInst &TI); void visitBranchInst(BranchInst &BI); @@ -1257,53 +1434,55 @@ namespace { void 
visitStoreInst(StoreInst &SI); void visitBinaryOperator(BinaryOperator &BO); - void visitCmpInst(CmpInst &CI) {} }; // Used by terminator instructions to proceed from the current basic // block to the next. Verifies that "current" dominates "next", // then calls visitBasicBlock. - void proceedToSuccessors(const InequalityGraph &IG, BasicBlock *BBCurrent) { - DominatorTree::Node *Current = DT->getNode(BBCurrent); + void proceedToSuccessors(DominatorTree::Node *Current) { for (DominatorTree::Node::iterator I = Current->begin(), E = Current->end(); I != E; ++I) { - //visitBasicBlock((*I)->getBlock(), IG); - WorkList.push_back(State((*I)->getBlock(), new InequalityGraph(IG))); + WorkList.push_back(*I); } } - void proceedToSuccessor(InequalityGraph *NextIG, BasicBlock *Next) { - //visitBasicBlock(Next, NextIG); - WorkList.push_back(State(Next, NextIG)); + void proceedToSuccessor(DominatorTree::Node *Next) { + WorkList.push_back(Next); } // Visits each instruction in the basic block. - void visitBasicBlock(BasicBlock *BB, InequalityGraph &IG) { - DOUT << "Entering Basic Block: " << BB->getName() << "\n"; + void visitBasicBlock(DominatorTree::Node *Node) { + BasicBlock *BB = Node->getBlock(); + ETNode *ET = Forest->getNodeForBlock(BB); + DOUT << "Entering Basic Block: " << BB->getName() << " (" << ET << ")\n"; for (BasicBlock::iterator I = BB->begin(), E = BB->end(); I != E;) { - visitInstruction(I++, IG); + visitInstruction(I++, Node, ET); } } // Tries to simplify each Instruction and add new properties to // the PropertySet. - void visitInstruction(Instruction *I, InequalityGraph &IG) { + void visitInstruction(Instruction *I, DominatorTree::Node *DT, ETNode *ET) { DOUT << "Considering instruction " << *I << "\n"; - DEBUG(IG.debug(*cerr.stream())); + DEBUG(IG->dump()); - // Sometimes instructions are made dead due to earlier analysis. + // Sometimes instructions are killed in earlier analysis. 
if (isInstructionTriviallyDead(I)) { + ++NumSimple; + modified = true; + IG->remove(I); I->eraseFromParent(); return; } // Try to replace the whole instruction. - Value *V = IG.canonicalize(I); + Value *V = IG->canonicalize(I, ET); + assert(V == I && "Late instruction canonicalization."); if (V != I) { modified = true; ++NumInstruction; DOUT << "Removing " << *I << ", replacing with " << *V << "\n"; - IG.remove(I); + IG->remove(I); I->replaceAllUsesWith(V); I->eraseFromParent(); return; @@ -1312,7 +1491,8 @@ namespace { // Try to substitute operands. for (unsigned i = 0, e = I->getNumOperands(); i != e; ++i) { Value *Oper = I->getOperand(i); - Value *V = IG.canonicalize(Oper); + Value *V = IG->canonicalize(Oper, ET); + assert(V == Oper && "Late operand canonicalization."); if (V != Oper) { modified = true; ++NumVarsReplaced; @@ -1322,10 +1502,10 @@ namespace { } } - //DOUT << "push (%" << I->getParent()->getName() << ")\n"; - Forwards visit(this, IG); + DOUT << "push (%" << I->getParent()->getName() << ")\n"; + Forwards visit(this, DT); visit.visit(*I); - //DOUT << "pop (%" << I->getParent()->getName() << ")\n"; + DOUT << "pop (%" << I->getParent()->getName() << ")\n"; } }; @@ -1333,64 +1513,68 @@ namespace { DT = &getAnalysis(); Forest = &getAnalysis(); + Forest->updateDFSNumbers(); // XXX: should only act when numbers are out of date + DOUT << "Entering Function: " << F.getName() << "\n"; modified = false; - WorkList.push_back(State(DT->getRoot(), new InequalityGraph())); + BasicBlock *RootBlock = &F.getEntryBlock(); + IG = new InequalityGraph(Forest->getNodeForBlock(RootBlock)); + WorkList.push_back(DT->getRootNode()); do { - State S = WorkList.back(); + DominatorTree::Node *DTNode = WorkList.back(); WorkList.pop_back(); - visitBasicBlock(S.ToVisit, *S.IG); - delete S.IG; + if (!UB.isDead(DTNode->getBlock())) visitBasicBlock(DTNode); } while (!WorkList.empty()); - //DEBUG(F.viewCFG()); + delete IG; + + modified |= UB.kill(); return modified; } void 
PredicateSimplifier::Forwards::visitTerminatorInst(TerminatorInst &TI) { - PS->proceedToSuccessors(IG, TI.getParent()); + PS->proceedToSuccessors(DTNode); } void PredicateSimplifier::Forwards::visitBranchInst(BranchInst &BI) { - BasicBlock *BB = BI.getParent(); - if (BI.isUnconditional()) { - PS->proceedToSuccessors(IG, BB); + PS->proceedToSuccessors(DTNode); return; } Value *Condition = BI.getCondition(); - BasicBlock *TrueDest = BI.getSuccessor(0), - *FalseDest = BI.getSuccessor(1); + BasicBlock *TrueDest = BI.getSuccessor(0); + BasicBlock *FalseDest = BI.getSuccessor(1); - if (isa(Condition) || TrueDest == FalseDest) { - PS->proceedToSuccessors(IG, BB); + if (isa(Condition) || TrueDest == FalseDest) { + PS->proceedToSuccessors(DTNode); return; } - DominatorTree::Node *Node = PS->DT->getNode(BB); - for (DominatorTree::Node::iterator I = Node->begin(), E = Node->end(); + for (DominatorTree::Node::iterator I = DTNode->begin(), E = DTNode->end(); I != E; ++I) { BasicBlock *Dest = (*I)->getBlock(); - InequalityGraph *DestProperties = new InequalityGraph(IG); - VRPSolver Solver(*DestProperties, PS->Forest, Dest); + DOUT << "Branch thinking about %" << Dest->getName() + << "(" << PS->Forest->getNodeForBlock(Dest) << ")\n"; if (Dest == TrueDest) { - DOUT << "(" << BB->getName() << ") true set:\n"; - if (!Solver.addEqual(ConstantBool::getTrue(), Condition)) continue; - Solver.solve(); - DEBUG(DestProperties->debug(*cerr.stream())); + DOUT << "(" << DTNode->getBlock()->getName() << ") true set:\n"; + VRPSolver VRP(IG, UB, PS->Forest, PS->modified, Dest); + VRP.add(ConstantBool::getTrue(), Condition, ICmpInst::ICMP_EQ); + VRP.solve(); + DEBUG(IG.dump()); } else if (Dest == FalseDest) { - DOUT << "(" << BB->getName() << ") false set:\n"; - if (!Solver.addEqual(ConstantBool::getFalse(), Condition)) continue; - Solver.solve(); - DEBUG(DestProperties->debug(*cerr.stream())); + DOUT << "(" << DTNode->getBlock()->getName() << ") false set:\n"; + VRPSolver VRP(IG, UB, PS->Forest, 
PS->modified, Dest); + VRP.add(ConstantBool::getFalse(), Condition, ICmpInst::ICMP_EQ); + VRP.solve(); + DEBUG(IG.dump()); } - PS->proceedToSuccessor(DestProperties, Dest); + PS->proceedToSuccessor(*I); } } @@ -1399,31 +1583,30 @@ namespace { // Set the EQProperty in each of the cases BBs, and the NEProperties // in the default BB. - // InequalityGraph DefaultProperties(IG); - DominatorTree::Node *Node = PS->DT->getNode(SI.getParent()); - for (DominatorTree::Node::iterator I = Node->begin(), E = Node->end(); + for (DominatorTree::Node::iterator I = DTNode->begin(), E = DTNode->end(); I != E; ++I) { BasicBlock *BB = (*I)->getBlock(); + DOUT << "Switch thinking about BB %" << BB->getName() + << "(" << PS->Forest->getNodeForBlock(BB) << ")\n"; - InequalityGraph *BBProperties = new InequalityGraph(IG); - VRPSolver Solver(*BBProperties, PS->Forest, BB); + VRPSolver VRP(IG, UB, PS->Forest, PS->modified, BB); if (BB == SI.getDefaultDest()) { for (unsigned i = 1, e = SI.getNumCases(); i < e; ++i) if (SI.getSuccessor(i) != BB) - if (!Solver.addNotEqual(Condition, SI.getCaseValue(i))) continue; - Solver.solve(); + VRP.add(Condition, SI.getCaseValue(i), ICmpInst::ICMP_NE); + VRP.solve(); } else if (ConstantInt *CI = SI.findCaseDest(BB)) { - if (!Solver.addEqual(Condition, CI)) continue; - Solver.solve(); + VRP.add(Condition, CI, ICmpInst::ICMP_EQ); + VRP.solve(); } - PS->proceedToSuccessor(BBProperties, BB); + PS->proceedToSuccessor(*I); } } void PredicateSimplifier::Forwards::visitAllocaInst(AllocaInst &AI) { - VRPSolver VRP(IG, PS->Forest, AI.getParent()); - VRP.addNotEqual(Constant::getNullValue(AI.getType()), &AI); + VRPSolver VRP(IG, UB, PS->Forest, PS->modified, &AI); + VRP.add(Constant::getNullValue(AI.getType()), &AI, ICmpInst::ICMP_NE); VRP.solve(); } @@ -1432,8 +1615,8 @@ namespace { // avoid "load uint* null" -> null NE null. 
if (isa(Ptr)) return; - VRPSolver VRP(IG, PS->Forest, LI.getParent()); - VRP.addNotEqual(Constant::getNullValue(Ptr->getType()), Ptr); + VRPSolver VRP(IG, UB, PS->Forest, PS->modified, &LI); + VRP.add(Constant::getNullValue(Ptr->getType()), Ptr, ICmpInst::ICMP_NE); VRP.solve(); } @@ -1441,8 +1624,8 @@ namespace { Value *Ptr = SI.getPointerOperand(); if (isa(Ptr)) return; - VRPSolver VRP(IG, PS->Forest, SI.getParent()); - VRP.addNotEqual(Constant::getNullValue(Ptr->getType()), Ptr); + VRPSolver VRP(IG, UB, PS->Forest, PS->modified, &SI); + VRP.add(Constant::getNullValue(Ptr->getType()), Ptr, ICmpInst::ICMP_NE); VRP.solve(); } @@ -1452,13 +1635,12 @@ namespace { switch (ops) { case Instruction::URem: case Instruction::SRem: - case Instruction::FRem: case Instruction::UDiv: - case Instruction::SDiv: - case Instruction::FDiv: { + case Instruction::SDiv: { Value *Divisor = BO.getOperand(1); - VRPSolver VRP(IG, PS->Forest, BO.getParent()); - VRP.addNotEqual(Constant::getNullValue(Divisor->getType()), Divisor); + VRPSolver VRP(IG, UB, PS->Forest, PS->modified, &BO); + VRP.add(Constant::getNullValue(Divisor->getType()), Divisor, + ICmpInst::ICMP_NE); VRP.solve(); break; } @@ -1467,7 +1649,6 @@ namespace { } } - RegisterPass X("predsimplify", "Predicate Simplifier"); } diff --git a/test/Transforms/PredicateSimplifier/2007-01-04-SelectSwitch.ll b/test/Transforms/PredicateSimplifier/2007-01-04-SelectSwitch.ll new file mode 100644 index 00000000000..141d4e27c69 --- /dev/null +++ b/test/Transforms/PredicateSimplifier/2007-01-04-SelectSwitch.ll @@ -0,0 +1,19 @@ +; RUN: llvm-upgrade < %s | llvm-as | opt -predsimplify -disable-output + +void %ercMarkCurrMBConcealed(int %comp) { +entry: + %tmp5 = icmp slt int %comp, 0 ; [#uses=2] + %comp_addr.0 = select bool %tmp5, int 0, int %comp ; [#uses=1] + switch int %comp_addr.0, label %return [ + int 0, label %bb + ] + +bb: ; preds = %entry + br bool %tmp5, label %bb87.bb97_crit_edge.critedge, label %return + +bb87.bb97_crit_edge.critedge: 
; preds = %bb + ret void + +return: ; preds = %bb, %entry + ret void +}