diff --git a/include/llvm/CodeGen/LiveIntervalAnalysis.h b/include/llvm/CodeGen/LiveIntervalAnalysis.h index dd521a40c0f..292c5111209 100644 --- a/include/llvm/CodeGen/LiveIntervalAnalysis.h +++ b/include/llvm/CodeGen/LiveIntervalAnalysis.h @@ -26,6 +26,7 @@ #include #include #include +#include namespace llvm { @@ -39,10 +40,9 @@ namespace llvm { typedef std::pair Range; typedef std::vector Ranges; unsigned reg; // the register of this interval - unsigned hint; float weight; // weight of this interval (number of uses // * 10^loopDepth) - Ranges ranges; // the ranges this register is valid + Ranges ranges; // the ranges in which this register is live Interval(unsigned r); @@ -66,10 +66,12 @@ namespace llvm { void addRange(unsigned start, unsigned end); - private: - void mergeRangesForward(Ranges::iterator it); + void join(const Interval& other); - void mergeRangesBackward(Ranges::iterator it); + private: + Ranges::iterator mergeRangesForward(Ranges::iterator it); + + Ranges::iterator mergeRangesBackward(Ranges::iterator it); }; struct StartPointComp { @@ -85,6 +87,7 @@ namespace llvm { }; typedef std::list Intervals; + typedef std::map Reg2RegMap; typedef std::vector MachineBasicBlockPtrs; private: @@ -104,11 +107,15 @@ namespace llvm { typedef std::map Reg2IntervalMap; Reg2IntervalMap r2iMap_; + Reg2RegMap r2rMap_; + Intervals intervals_; public: virtual void getAnalysisUsage(AnalysisUsage &AU) const; + Intervals& getIntervals() { return intervals_; } + MachineBasicBlockPtrs getOrderedMachineBasicBlockPtrs() const { MachineBasicBlockPtrs result; for (MbbIndex2MbbMap::const_iterator @@ -119,6 +126,13 @@ namespace llvm { return result; } + const Reg2RegMap& getJoinedRegMap() const { + return r2rMap_; + } + + /// rep - returns the representative of this register + unsigned rep(unsigned reg); + private: /// runOnMachineFunction - pass entry point bool runOnMachineFunction(MachineFunction&); @@ -126,6 +140,8 @@ namespace llvm { /// computeIntervals - compute live intervals void computeIntervals(); + /// joinIntervals - join compatible live intervals + void joinIntervals(); /// handleRegisterDef - update intervals for a register def /// (calls handlePhysicalRegisterDef and diff --git a/include/llvm/CodeGen/LiveIntervals.h b/include/llvm/CodeGen/LiveIntervals.h index dd521a40c0f..292c5111209 100644 --- a/include/llvm/CodeGen/LiveIntervals.h +++ b/include/llvm/CodeGen/LiveIntervals.h @@ -26,6 +26,7 @@ #include #include #include +#include namespace llvm { @@ -39,10 +40,9 @@ namespace llvm { typedef std::pair Range; typedef std::vector Ranges; unsigned reg; // the register of this interval - unsigned hint; float weight; // weight of this interval (number of uses // * 10^loopDepth) - Ranges ranges; // the ranges this register is valid + Ranges ranges; // the ranges in which this register is live Interval(unsigned r); @@ -66,10 +66,12 @@ namespace llvm { void addRange(unsigned start, unsigned end); - private: - void mergeRangesForward(Ranges::iterator it); + void join(const Interval& other); - void mergeRangesBackward(Ranges::iterator it); + private: + Ranges::iterator mergeRangesForward(Ranges::iterator it); + + Ranges::iterator mergeRangesBackward(Ranges::iterator it); }; struct StartPointComp { @@ -85,6 +87,7 @@ namespace llvm { }; typedef std::list Intervals; + typedef std::map Reg2RegMap; typedef std::vector MachineBasicBlockPtrs; private: @@ -104,11 +107,15 @@ namespace llvm { typedef std::map Reg2IntervalMap; Reg2IntervalMap r2iMap_; + Reg2RegMap r2rMap_; + Intervals intervals_; public: virtual void getAnalysisUsage(AnalysisUsage &AU) const; + Intervals& getIntervals() { return intervals_; } + MachineBasicBlockPtrs getOrderedMachineBasicBlockPtrs() const { MachineBasicBlockPtrs result; for (MbbIndex2MbbMap::const_iterator @@ -119,6 +126,13 @@ namespace llvm { return result; } + const Reg2RegMap& getJoinedRegMap() const { + return r2rMap_; + } + + /// rep - returns the representative of this register + unsigned rep(unsigned reg); + private: /// runOnMachineFunction - pass entry point bool runOnMachineFunction(MachineFunction&); @@ -126,6 +140,8 @@ namespace llvm { /// computeIntervals - compute live intervals void computeIntervals(); + /// joinIntervals - join compatible live intervals + void joinIntervals(); /// handleRegisterDef - update intervals for a register def /// (calls handlePhysicalRegisterDef and diff --git a/lib/CodeGen/LiveIntervalAnalysis.cpp b/lib/CodeGen/LiveIntervalAnalysis.cpp index eb9cf38b82a..fb11ff093bb 100644 --- a/lib/CodeGen/LiveIntervalAnalysis.cpp +++ b/lib/CodeGen/LiveIntervalAnalysis.cpp @@ -30,6 +30,7 @@ #include "llvm/Target/TargetMachine.h" #include "llvm/Target/TargetRegInfo.h" #include "llvm/Support/CFG.h" +#include "Support/CommandLine.h" #include "Support/Debug.h" #include "Support/DepthFirstIterator.h" #include "Support/Statistic.h" @@ -44,6 +45,11 @@ namespace { "Live Interval Analysis"); Statistic<> numIntervals("liveintervals", "Number of intervals"); + + cl::opt + join("join-liveintervals", + cl::desc("Join compatible live intervals"), + cl::init(false)); }; void LiveIntervals::getAnalysisUsage(AnalysisUsage &AU) const @@ -69,6 +75,7 @@ bool LiveIntervals::runOnMachineFunction(MachineFunction &fn) { mi2iMap_.clear(); r2iMap_.clear(); r2iMap_.clear(); + r2rMap_.clear(); intervals_.clear(); // number MachineInstrs @@ -95,18 +102,17 @@ bool LiveIntervals::runOnMachineFunction(MachineFunction &fn) { const LoopInfo& loopInfo = getAnalysis(); const TargetInstrInfo& tii = tm_->getInstrInfo(); - for (MbbIndex2MbbMap::iterator - it = mbbi2mbbMap_.begin(), itEnd = mbbi2mbbMap_.end(); - it != itEnd; ++it) { - MachineBasicBlock* mbb = it->second; - + for (MachineFunction::const_iterator mbbi = mf_->begin(), + mbbe = mf_->end(); mbbi != mbbe; ++mbbi) { + const MachineBasicBlock* mbb = mbbi; unsigned loopDepth = loopInfo.getLoopDepth(mbb->getBasicBlock()); - for (MachineBasicBlock::iterator mi = mbb->begin(), miEnd = mbb->end(); - mi != miEnd; ++mi) { - MachineInstr* instr = *mi; - for (int i = instr->getNumOperands() - 1; i >= 0; --i) { - MachineOperand& mop = instr->getOperand(i); + for (MachineBasicBlock::const_iterator mii = mbb->begin(), + mie = mbb->end(); mii != mie; ++mii) { + MachineInstr* mi = *mii; + + for (int i = mi->getNumOperands() - 1; i >= 0; --i) { + MachineOperand& mop = mi->getOperand(i); if (!mop.isVirtualRegister()) continue; @@ -119,6 +125,9 @@ bool LiveIntervals::runOnMachineFunction(MachineFunction &fn) { } } + // join intervals if requested + if (join) joinIntervals(); + numIntervals += intervals_.size(); return true; @@ -326,48 +335,111 @@ void LiveIntervals::computeIntervals() std::ostream_iterator(std::cerr, "\n"))); } +unsigned LiveIntervals::rep(unsigned reg) +{ + Reg2RegMap::iterator it = r2rMap_.find(reg); + if (it != r2rMap_.end()) + return it->second = rep(it->second); + return reg; +} + +void LiveIntervals::joinIntervals() +{ + DEBUG(std::cerr << "joining compatible intervals:\n"); + + const TargetInstrInfo& tii = tm_->getInstrInfo(); + + for (MachineFunction::const_iterator mbbi = mf_->begin(), + mbbe = mf_->end(); mbbi != mbbe; ++mbbi) { + const MachineBasicBlock* mbb = mbbi; + DEBUG(std::cerr << "machine basic block: " + << mbb->getBasicBlock()->getName() << "\n"); + + for (MachineBasicBlock::const_iterator mii = mbb->begin(), + mie = mbb->end(); mii != mie; ++mii) { + MachineInstr* mi = *mii; + const TargetInstrDescriptor& tid = + tm_->getInstrInfo().get(mi->getOpcode()); + DEBUG(std::cerr << "\t\tinstruction[" + << getInstructionIndex(mi) << "]: "; + mi->print(std::cerr, *tm_);); + + unsigned srcReg, dstReg; + if (tii.isMoveInstr(*mi, srcReg, dstReg) && + (srcReg >= MRegisterInfo::FirstVirtualRegister || + lv_->getAllocatablePhysicalRegisters()[srcReg]) && + (dstReg >= MRegisterInfo::FirstVirtualRegister || + lv_->getAllocatablePhysicalRegisters()[dstReg])) { + + // get representative registers + srcReg = rep(srcReg); + dstReg = rep(dstReg); + + // if they are already joined we continue + if (srcReg == dstReg) + continue; + + Reg2IntervalMap::iterator r2iSrc = r2iMap_.find(srcReg); + assert(r2iSrc != r2iMap_.end()); + Reg2IntervalMap::iterator r2iDst = r2iMap_.find(dstReg); + assert(r2iDst != r2iMap_.end()); + + Intervals::iterator srcInt = r2iSrc->second; + Intervals::iterator dstInt = r2iDst->second; + + if (srcInt->reg < MRegisterInfo::FirstVirtualRegister) { + if (dstInt->reg == srcInt->reg || + (dstInt->reg >= MRegisterInfo::FirstVirtualRegister && + !dstInt->overlaps(*srcInt))) { + srcInt->join(*dstInt); + r2iDst->second = r2iSrc->second; + r2rMap_.insert(std::make_pair(dstInt->reg, srcInt->reg)); + intervals_.erase(dstInt); + } + } + else if (dstInt->reg < MRegisterInfo::FirstVirtualRegister) { + if (srcInt->reg == dstInt->reg || + (srcInt->reg >= MRegisterInfo::FirstVirtualRegister && + !srcInt->overlaps(*dstInt))) { + dstInt->join(*srcInt); + r2iSrc->second = r2iDst->second; + r2rMap_.insert(std::make_pair(srcInt->reg, dstInt->reg)); + intervals_.erase(srcInt); + } + } + else { + const TargetRegisterClass *srcRc, *dstRc; + srcRc = mf_->getSSARegMap()->getRegClass(srcInt->reg); + dstRc = mf_->getSSARegMap()->getRegClass(dstInt->reg); + + if (srcRc == dstRc && !dstInt->overlaps(*srcInt)) { + srcInt->join(*dstInt); + r2iDst->second = r2iSrc->second; + r2rMap_.insert(std::make_pair(dstInt->reg, srcInt->reg)); + intervals_.erase(dstInt); + } + } + } + } + } + + intervals_.sort(StartPointComp()); + DEBUG(std::copy(intervals_.begin(), intervals_.end(), + std::ostream_iterator(std::cerr, "\n"))); + DEBUG(for (Reg2RegMap::const_iterator i = r2rMap_.begin(), + e = r2rMap_.end(); i != e; ++i) + std::cerr << i->first << " -> " << i->second << '\n';); + +} + LiveIntervals::Interval::Interval(unsigned r) - : reg(r), hint(0), + : reg(r), weight((r < MRegisterInfo::FirstVirtualRegister ? std::numeric_limits::max() : 0.0F)) { } -void LiveIntervals::Interval::addRange(unsigned start, unsigned end) -{ - DEBUG(std::cerr << "\t\t\tadding range: [" << start <<','<< end << ") -> "); - //assert(start < end && "invalid range?"); - Range range = std::make_pair(start, end); - Ranges::iterator it = - ranges.insert(std::upper_bound(ranges.begin(), ranges.end(), range), - range); - - mergeRangesForward(it); - mergeRangesBackward(it); - DEBUG(std::cerr << *this << '\n'); -} - -void LiveIntervals::Interval::mergeRangesForward(Ranges::iterator it) -{ - for (Ranges::iterator next = it + 1; - next != ranges.end() && it->second >= next->first; ) { - it->second = std::max(it->second, next->second); - next = ranges.erase(next); - } -} - -void LiveIntervals::Interval::mergeRangesBackward(Ranges::iterator it) -{ - for (Ranges::iterator prev = it - 1; - it != ranges.begin() && it->first <= prev->second; ) { - it->first = std::min(it->first, prev->first); - it->second = std::max(it->second, prev->second); - it = ranges.erase(prev); - prev = it - 1; - } -} - bool LiveIntervals::Interval::liveAt(unsigned index) const { Ranges::const_iterator r = ranges.begin(); @@ -409,6 +481,63 @@ bool LiveIntervals::Interval::overlaps(const Interval& other) const return false; } +void LiveIntervals::Interval::addRange(unsigned start, unsigned end) +{ + DEBUG(std::cerr << "\t\t\tadding range: [" << start <<','<< end << ") -> "); + //assert(start < end && "invalid range?"); + Range range = std::make_pair(start, end); + Ranges::iterator it = + ranges.insert(std::upper_bound(ranges.begin(), ranges.end(), range), + range); + + it = mergeRangesForward(it); + it = mergeRangesBackward(it); + DEBUG(std::cerr << "\t\t\t\tafter merging: " << *this << '\n'); +} + +void LiveIntervals::Interval::join(const LiveIntervals::Interval& other) +{ + DEBUG(std::cerr << "\t\t\t\tjoining intervals: " + << other << " and " << *this << '\n'); + Ranges::iterator cur = ranges.begin(); + + for (Ranges::const_iterator i = other.ranges.begin(), + e = other.ranges.end(); i != e; ++i) { + cur = ranges.insert(std::upper_bound(cur, ranges.end(), *i), *i); + cur = mergeRangesForward(cur); + cur = mergeRangesBackward(cur); + } + if (reg >= MRegisterInfo::FirstVirtualRegister) + weight += other.weight; + + DEBUG(std::cerr << "\t\t\t\tafter merging: " << *this << '\n'); +} + +LiveIntervals::Interval::Ranges::iterator +LiveIntervals::Interval::mergeRangesForward(Ranges::iterator it) +{ + for (Ranges::iterator next = it + 1; + next != ranges.end() && it->second >= next->first; ) { + it->second = std::max(it->second, next->second); + next = ranges.erase(next); + } + return it; +} + +LiveIntervals::Interval::Ranges::iterator +LiveIntervals::Interval::mergeRangesBackward(Ranges::iterator it) +{ + for (Ranges::iterator prev = it - 1; + it != ranges.begin() && it->first <= prev->second; ) { + it->first = std::min(it->first, prev->first); + it->second = std::max(it->second, prev->second); + it = ranges.erase(prev); + prev = it - 1; + } + + return it; +} + std::ostream& llvm::operator<<(std::ostream& os, const LiveIntervals::Interval& li) { diff --git a/lib/CodeGen/LiveIntervalAnalysis.h b/lib/CodeGen/LiveIntervalAnalysis.h index dd521a40c0f..292c5111209 100644 --- a/lib/CodeGen/LiveIntervalAnalysis.h +++ b/lib/CodeGen/LiveIntervalAnalysis.h @@ -26,6 +26,7 @@ #include #include #include +#include namespace llvm { @@ -39,10 +40,9 @@ namespace llvm { typedef std::pair Range; typedef std::vector Ranges; unsigned reg; // the register of this interval - unsigned hint; float weight; // weight of this interval (number of uses // * 10^loopDepth) - Ranges ranges; // the ranges this register is valid + Ranges ranges; // the ranges in which this register is live Interval(unsigned r); @@ -66,10 +66,12 @@ namespace llvm { void addRange(unsigned start, unsigned end); - private: - void mergeRangesForward(Ranges::iterator it); + void join(const Interval& other); - void mergeRangesBackward(Ranges::iterator it); + private: + Ranges::iterator mergeRangesForward(Ranges::iterator it); + + Ranges::iterator mergeRangesBackward(Ranges::iterator it); }; struct StartPointComp { @@ -85,6 +87,7 @@ namespace llvm { }; typedef std::list Intervals; + typedef std::map Reg2RegMap; typedef std::vector MachineBasicBlockPtrs; private: @@ -104,11 +107,15 @@ namespace llvm { typedef std::map Reg2IntervalMap; Reg2IntervalMap r2iMap_; + Reg2RegMap r2rMap_; + Intervals intervals_; public: virtual void getAnalysisUsage(AnalysisUsage &AU) const; + Intervals& getIntervals() { return intervals_; } + MachineBasicBlockPtrs getOrderedMachineBasicBlockPtrs() const { MachineBasicBlockPtrs result; for (MbbIndex2MbbMap::const_iterator @@ -119,6 +126,13 @@ namespace llvm { return result; } + const Reg2RegMap& getJoinedRegMap() const { + return r2rMap_; + } + + /// rep - returns the representative of this register + unsigned rep(unsigned reg); + private: /// runOnMachineFunction - pass entry point bool runOnMachineFunction(MachineFunction&); @@ -126,6 +140,8 @@ namespace llvm { /// computeIntervals - compute live intervals void computeIntervals(); + /// joinIntervals - join compatible live intervals + void joinIntervals(); /// handleRegisterDef - update intervals for a register def /// (calls handlePhysicalRegisterDef and diff --git a/lib/CodeGen/RegAllocLinearScan.cpp b/lib/CodeGen/RegAllocLinearScan.cpp index 5f1290d233c..d25dc345291 100644 --- a/lib/CodeGen/RegAllocLinearScan.cpp +++ b/lib/CodeGen/RegAllocLinearScan.cpp @@ -32,12 +32,15 @@ using namespace llvm; namespace { Statistic<> numSpilled ("ra-linearscan", "Number of registers spilled"); Statistic<> numReloaded("ra-linearscan", "Number of registers reloaded"); + Statistic<> numPeep ("ra-linearscan", + "Number of identity moves eliminated"); class RA : public MachineFunctionPass { private: MachineFunction* mf_; const TargetMachine* tm_; const MRegisterInfo* mri_; + LiveIntervals* li_; MachineFunction::iterator currentMbb_; MachineBasicBlock::iterator currentInstr_; typedef std::vector IntervalPtrs; @@ -203,8 +206,8 @@ bool RA::runOnMachineFunction(MachineFunction &fn) { mf_ = &fn; tm_ = &fn.getTarget(); mri_ = tm_->getRegisterInfo(); - - initIntervalSets(getAnalysis().getIntervals()); + li_ = &getAnalysis(); + initIntervalSets(li_->getIntervals()); v2pMap_.clear(); v2ssMap_.clear(); @@ -231,6 +234,7 @@ bool RA::runOnMachineFunction(MachineFunction &fn) { reserved_[29] = true; /* FP6 */ // linear scan algorithm + DEBUG(std::cerr << "Machine Function\n"); DEBUG(printIntervals("\tunhandled", unhandled_.begin(), unhandled_.end())); DEBUG(printIntervals("\tfixed", fixed_.begin(), fixed_.end())); @@ -324,15 +328,54 @@ bool RA::runOnMachineFunction(MachineFunction &fn) { active_.clear(); inactive_.clear(); - DEBUG(std::cerr << "finished register allocation\n"); + typedef LiveIntervals::Reg2RegMap Reg2RegMap; + const Reg2RegMap& r2rMap = li_->getJoinedRegMap(); + DEBUG(printVirtRegAssignment()); + DEBUG(std::cerr << "Performing coalescing on joined intervals\n"); + // perform coalescing if we were passed joined intervals + for(Reg2RegMap::const_iterator i = r2rMap.begin(), e = r2rMap.end(); + i != e; ++i) { + unsigned reg = i->first; + unsigned rep = li_->rep(reg); + + assert((rep < MRegisterInfo::FirstVirtualRegister || + v2pMap_.find(rep) != v2pMap_.end() || + v2ssMap_.find(rep) != v2ssMap_.end()) && + "representative register is not allocated!"); + + assert(reg >= MRegisterInfo::FirstVirtualRegister && + v2pMap_.find(reg) == v2pMap_.end() && + v2ssMap_.find(reg) == v2ssMap_.end() && + "coalesced register is already allocated!"); + + if (rep < MRegisterInfo::FirstVirtualRegister) { + v2pMap_.insert(std::make_pair(reg, rep)); + } + else { + Virt2PhysMap::const_iterator pr = v2pMap_.find(rep); + if (pr != v2pMap_.end()) { + v2pMap_.insert(std::make_pair(reg, pr->second)); + } + else { + Virt2StackSlotMap::const_iterator ss = v2ssMap_.find(rep); + assert(ss != v2ssMap_.end()); + v2ssMap_.insert(std::make_pair(reg, ss->second)); + } + } + } + + DEBUG(printVirtRegAssignment()); + DEBUG(std::cerr << "finished register allocation\n"); + + const TargetInstrInfo& tii = tm_->getInstrInfo(); DEBUG(std::cerr << "Rewrite machine code:\n"); for (currentMbb_ = mf_->begin(); currentMbb_ != mf_->end(); ++currentMbb_) { instrAdded_ = 0; for (currentInstr_ = currentMbb_->begin(); - currentInstr_ != currentMbb_->end(); ++currentInstr_) { + currentInstr_ != currentMbb_->end(); ) { DEBUG(std::cerr << "\tinstruction: "; (*currentInstr_)->print(std::cerr, *tm_);); @@ -355,6 +398,21 @@ bool RA::runOnMachineFunction(MachineFunction &fn) { } } + unsigned srcReg, dstReg; + if (tii.isMoveInstr(**currentInstr_, srcReg, dstReg) && + ((srcReg < MRegisterInfo::FirstVirtualRegister && + dstReg < MRegisterInfo::FirstVirtualRegister && + srcReg == dstReg) || + (srcReg >= MRegisterInfo::FirstVirtualRegister && + dstReg >= MRegisterInfo::FirstVirtualRegister && + v2ssMap_[srcReg] == v2ssMap_[dstReg]))) { + delete *currentInstr_; + currentInstr_ = currentMbb_->erase(currentInstr_); + ++numPeep; + DEBUG(std::cerr << "\t\tdeleting instruction\n"); + continue; + } + DEBUG(std::cerr << "\t\tloading temporarily used operands to " "registers:\n"); for (unsigned i = 0, e = (*currentInstr_)->getNumOperands(); @@ -416,6 +474,7 @@ bool RA::runOnMachineFunction(MachineFunction &fn) { } --currentInstr_; // restore currentInstr_ iterator tempDefOperands_.clear(); + ++currentInstr_; } }