diff --git a/lib/Transforms/IPO/IPConstantPropagation.cpp b/lib/Transforms/IPO/IPConstantPropagation.cpp index 4ebdaf3fb2b..0e654a50eef 100644 --- a/lib/Transforms/IPO/IPConstantPropagation.cpp +++ b/lib/Transforms/IPO/IPConstantPropagation.cpp @@ -21,6 +21,7 @@ #include "llvm/Instructions.h" #include "llvm/Module.h" #include "llvm/Pass.h" +#include "llvm/Analysis/ValueTracking.h" #include "llvm/Support/CallSite.h" #include "llvm/Support/Compiler.h" #include "llvm/ADT/Statistic.h" @@ -140,9 +141,10 @@ bool IPCP::PropagateConstantsIntoArguments(Function &F) { } -// Check to see if this function returns a constant. If so, replace all callers -// that user the return value with the returned valued. If we can replace ALL -// callers, +// Check to see if this function returns one or more constants. If so, replace +// all callers that use those return values with the constant value. This will +// leave in the actual return values and instructions, but deadargelim will +// clean that up. bool IPCP::PropagateConstantReturn(Function &F) { if (F.getReturnType() == Type::VoidTy) return false; // No return value. @@ -156,48 +158,65 @@ bool IPCP::PropagateConstantReturn(Function &F) { SmallVector RetVals; const StructType *STy = dyn_cast(F.getReturnType()); if (STy) - RetVals.assign(STy->getNumElements(), 0); + for (unsigned i = 0, e = STy->getNumElements(); i < e; ++i) + RetVals.push_back(UndefValue::get(STy->getElementType(i))); else - RetVals.push_back(0); + RetVals.push_back(UndefValue::get(F.getReturnType())); + unsigned NumNonConstant = 0; for (Function::iterator BB = F.begin(), E = F.end(); BB != E; ++BB) if (ReturnInst *RI = dyn_cast(BB->getTerminator())) { - assert(RetVals.size() == RI->getNumOperands() && - "Invalid ReturnInst operands!"); + // Return type does not match operand type, this is an old style multiple + // return + bool OldReturn = (F.getReturnType() != RI->getOperand(0)->getType()); + for (unsigned i = 0, e = RetVals.size(); i != e; ++i) { - if (isa(RI->getOperand(i))) - continue; // Ignore - Constant *C = dyn_cast(RI->getOperand(i)); - if (C == 0) - return false; // Does not return a constant. - + // Already found conflicting return values? Value *RV = RetVals[i]; - if (RV == 0) - RetVals[i] = C; - else if (RV != C) - return false; // Does not return the same constant. + if (!RV) + continue; + + // Find the returned value + Value *V; + if (!STy || OldReturn) + V = RI->getOperand(i); + else + V = FindInsertedValue(RI->getOperand(0), i); + + if (V) { + // Ignore undefs, we can change them into anything + if (isa(V)) + continue; + + // Try to see if all the rets return the same constant. + if (isa(V)) { + if (isa(RV)) { + // No value found yet? Try the current one. + RetVals[i] = V; + continue; + } + // Returning the same value? Good. + if (RV == V) + continue; + } + } + // Different or no known return value? Don't propagate this return + // value. + RetVals[i] = 0; + // All values non constant? Stop looking. + if (++NumNonConstant == RetVals.size()) + return false; } } - if (STy) { - for (unsigned i = 0, e = RetVals.size(); i < e; ++i) - if (RetVals[i] == 0) - RetVals[i] = UndefValue::get(STy->getElementType(i)); - } else { - assert(RetVals.size() == 1); - if (RetVals[0] == 0) - RetVals[0] = UndefValue::get(F.getReturnType()); - } - - // If we got here, the function returns a constant value. Loop over all - // users, replacing any uses of the return value with the returned constant. - bool ReplacedAllUsers = true; + // If we got here, the function returns at least one constant value. Loop + // over all users, replacing any uses of the return value with the returned + // constant. bool MadeChange = false; for (Value::use_iterator UI = F.use_begin(), E = F.use_end(); UI != E; ++UI) { // Make sure this is an invoke or call and that the use is for the callee. if (!(isa(*UI) || isa(*UI)) || UI.getOperandNo() != 0) { - ReplacedAllUsers = false; continue; } @@ -212,28 +231,32 @@ bool IPCP::PropagateConstantReturn(Function &F) { continue; } - while (!Call->use_empty()) { - GetResultInst *GR = cast(Call->use_back()); - GR->replaceAllUsesWith(RetVals[GR->getIndex()]); - GR->eraseFromParent(); - } - } - - // If we replace all users with the returned constant, and there can be no - // other callers of the function, replace the constant being returned in the - // function with an undef value. - if (ReplacedAllUsers && F.hasInternalLinkage()) { - for (Function::iterator BB = F.begin(), E = F.end(); BB != E; ++BB) { - if (ReturnInst *RI = dyn_cast(BB->getTerminator())) { - for (unsigned i = 0, e = RetVals.size(); i < e; ++i) { - Value *RetVal = RetVals[i]; - if (isa(RetVal)) - continue; - Value *RV = UndefValue::get(RetVal->getType()); - if (RI->getOperand(i) != RV) { - RI->setOperand(i, RV); - MadeChange = true; - } + for (Value::use_iterator I = Call->use_begin(), E = Call->use_end(); + I != E;) { + Instruction *Ins = dyn_cast(*I); + + // Increment now, so we can remove the use + ++I; + + // Not an instruction? Ignore + if (!Ins) + continue; + + // Find the index of the retval to replace with + int index = -1; + if (GetResultInst *GR = dyn_cast(Ins)) + index = GR->getIndex(); + else if (ExtractValueInst *EV = dyn_cast(Ins)) + if (EV->hasIndices()) + index = *EV->idx_begin(); + + // If this use uses a specific return value, and we have a replacement, + // replace it. + if (index != -1) { + Value *New = RetVals[index]; + if (New) { + Ins->replaceAllUsesWith(New); + Ins->eraseFromParent(); } } } diff --git a/test/Transforms/IPConstantProp/return-constants.ll b/test/Transforms/IPConstantProp/return-constants.ll index 40567f80bd8..7205c2820a7 100644 --- a/test/Transforms/IPConstantProp/return-constants.ll +++ b/test/Transforms/IPConstantProp/return-constants.ll @@ -1,20 +1,41 @@ -; RUN: llvm-as < %s | opt -ipconstprop | llvm-dis | grep {add i32 21, 21} +; RUN: llvm-as < %s | opt -ipconstprop | llvm-dis > %t +;; Check that the 21 constants got propagated properly +; RUN: cat %t | grep {%M = add i32 21, 21} +;; Check that the second return values didn't get propagated +; RUN: cat %t | grep {%N = add i32 %B, %D} -define internal {i32, i32} @foo(i1 %C) { - br i1 %C, label %T, label %F +define internal {i32, i32} @foo(i1 %Q) { + br i1 %Q, label %T, label %F T: ; preds = %0 - ret i32 21, i32 21 + ret i32 21, i32 22 F: ; preds = %0 - ret i32 21, i32 21 + ret i32 21, i32 23 } -define i32 @caller(i1 %C) { - %X = call {i32, i32} @foo( i1 %C ) +define internal {i32, i32} @bar(i1 %Q) { + %A = insertvalue { i32, i32 } undef, i32 21, 0 + br i1 %Q, label %T, label %F + +T: ; preds = %0 + %B = insertvalue { i32, i32 } %A, i32 22, 1 + ret { i32, i32 } %B + +F: ; preds = %0 + %C = insertvalue { i32, i32 } %A, i32 23, 1 + ret { i32, i32 } %C +} + +define { i32, i32 } @caller(i1 %Q) { + %X = call {i32, i32} @foo( i1 %Q ) %A = getresult {i32, i32} %X, 0 %B = getresult {i32, i32} %X, 1 - %Y = add i32 %A, %B - ret i32 %Y + %Y = call {i32, i32} @bar( i1 %Q ) + %C = extractvalue {i32, i32} %Y, 0 + %D = extractvalue {i32, i32} %Y, 1 + %M = add i32 %A, %C + %N = add i32 %B, %D + ret { i32, i32 } %X }