Generalize createSCEV to be able to form SCEV expressions from

ConstantExprs.


git-svn-id: https://llvm.org/svn/llvm-project/llvm/trunk@52615 91177308-0d34-0410-b5e6-96231b3b80d8
This commit is contained in:
Dan Gohman 2008-06-22 19:56:46 +00:00
parent 17f1972c77
commit 6c459a28ec

View File

@ -1704,118 +1704,125 @@ SCEVHandle ScalarEvolutionsImpl::createSCEV(Value *V) {
if (!isa<IntegerType>(V->getType()))
return SE.getUnknown(V);
if (Instruction *I = dyn_cast<Instruction>(V)) {
switch (I->getOpcode()) {
case Instruction::Add:
return SE.getAddExpr(getSCEV(I->getOperand(0)),
getSCEV(I->getOperand(1)));
case Instruction::Mul:
return SE.getMulExpr(getSCEV(I->getOperand(0)),
getSCEV(I->getOperand(1)));
case Instruction::UDiv:
return SE.getUDivExpr(getSCEV(I->getOperand(0)),
getSCEV(I->getOperand(1)));
case Instruction::Sub:
return SE.getMinusSCEV(getSCEV(I->getOperand(0)),
getSCEV(I->getOperand(1)));
case Instruction::Or:
// If the RHS of the Or is a constant, we may have something like:
// X*4+1 which got turned into X*4|1. Handle this as an Add so loop
// optimizations will transparently handle this case.
//
// In order for this transformation to be safe, the LHS must be of the
// form X*(2^n) and the Or constant must be less than 2^n.
if (ConstantInt *CI = dyn_cast<ConstantInt>(I->getOperand(1))) {
SCEVHandle LHS = getSCEV(I->getOperand(0));
const APInt &CIVal = CI->getValue();
if (GetMinTrailingZeros(LHS) >=
(CIVal.getBitWidth() - CIVal.countLeadingZeros()))
return SE.getAddExpr(LHS, getSCEV(I->getOperand(1)));
}
break;
case Instruction::Xor:
// If the RHS of the xor is a signbit, then this is just an add.
// Instcombine turns add of signbit into xor as a strength reduction step.
if (ConstantInt *CI = dyn_cast<ConstantInt>(I->getOperand(1))) {
if (CI->getValue().isSignBit())
return SE.getAddExpr(getSCEV(I->getOperand(0)),
getSCEV(I->getOperand(1)));
else if (CI->isAllOnesValue())
return SE.getNotSCEV(getSCEV(I->getOperand(0)));
}
break;
unsigned Opcode = Instruction::UserOp1;
if (Instruction *I = dyn_cast<Instruction>(V))
Opcode = I->getOpcode();
else if (ConstantExpr *CE = dyn_cast<ConstantExpr>(V))
Opcode = CE->getOpcode();
else
return SE.getUnknown(V);
case Instruction::Shl:
// Turn shift left of a constant amount into a multiply.
if (ConstantInt *SA = dyn_cast<ConstantInt>(I->getOperand(1))) {
uint32_t BitWidth = cast<IntegerType>(V->getType())->getBitWidth();
Constant *X = ConstantInt::get(
APInt(BitWidth, 1).shl(SA->getLimitedValue(BitWidth)));
return SE.getMulExpr(getSCEV(I->getOperand(0)), getSCEV(X));
}
break;
case Instruction::Trunc:
return SE.getTruncateExpr(getSCEV(I->getOperand(0)), I->getType());
case Instruction::ZExt:
return SE.getZeroExtendExpr(getSCEV(I->getOperand(0)), I->getType());
case Instruction::SExt:
return SE.getSignExtendExpr(getSCEV(I->getOperand(0)), I->getType());
case Instruction::BitCast:
// BitCasts are no-op casts so we just eliminate the cast.
if (I->getType()->isInteger() &&
I->getOperand(0)->getType()->isInteger())
return getSCEV(I->getOperand(0));
break;
case Instruction::PHI:
return createNodeForPHI(cast<PHINode>(I));
case Instruction::Select:
// This could be a smax or umax that was lowered earlier.
// Try to recover it.
if (ICmpInst *ICI = dyn_cast<ICmpInst>(I->getOperand(0))) {
Value *LHS = ICI->getOperand(0);
Value *RHS = ICI->getOperand(1);
switch (ICI->getPredicate()) {
case ICmpInst::ICMP_SLT:
case ICmpInst::ICMP_SLE:
std::swap(LHS, RHS);
// fall through
case ICmpInst::ICMP_SGT:
case ICmpInst::ICMP_SGE:
if (LHS == I->getOperand(1) && RHS == I->getOperand(2))
return SE.getSMaxExpr(getSCEV(LHS), getSCEV(RHS));
else if (LHS == I->getOperand(2) && RHS == I->getOperand(1))
// -smax(-x, -y) == smin(x, y).
return SE.getNegativeSCEV(SE.getSMaxExpr(
SE.getNegativeSCEV(getSCEV(LHS)),
SE.getNegativeSCEV(getSCEV(RHS))));
break;
case ICmpInst::ICMP_ULT:
case ICmpInst::ICMP_ULE:
std::swap(LHS, RHS);
// fall through
case ICmpInst::ICMP_UGT:
case ICmpInst::ICMP_UGE:
if (LHS == I->getOperand(1) && RHS == I->getOperand(2))
return SE.getUMaxExpr(getSCEV(LHS), getSCEV(RHS));
else if (LHS == I->getOperand(2) && RHS == I->getOperand(1))
// ~umax(~x, ~y) == umin(x, y)
return SE.getNotSCEV(SE.getUMaxExpr(SE.getNotSCEV(getSCEV(LHS)),
SE.getNotSCEV(getSCEV(RHS))));
break;
default:
break;
}
}
default: // We cannot analyze this expression.
break;
User *U = cast<User>(V);
switch (Opcode) {
case Instruction::Add:
return SE.getAddExpr(getSCEV(U->getOperand(0)),
getSCEV(U->getOperand(1)));
case Instruction::Mul:
return SE.getMulExpr(getSCEV(U->getOperand(0)),
getSCEV(U->getOperand(1)));
case Instruction::UDiv:
return SE.getUDivExpr(getSCEV(U->getOperand(0)),
getSCEV(U->getOperand(1)));
case Instruction::Sub:
return SE.getMinusSCEV(getSCEV(U->getOperand(0)),
getSCEV(U->getOperand(1)));
case Instruction::Or:
// If the RHS of the Or is a constant, we may have something like:
// X*4+1 which got turned into X*4|1. Handle this as an Add so loop
// optimizations will transparently handle this case.
//
// In order for this transformation to be safe, the LHS must be of the
// form X*(2^n) and the Or constant must be less than 2^n.
if (ConstantInt *CI = dyn_cast<ConstantInt>(U->getOperand(1))) {
SCEVHandle LHS = getSCEV(U->getOperand(0));
const APInt &CIVal = CI->getValue();
if (GetMinTrailingZeros(LHS) >=
(CIVal.getBitWidth() - CIVal.countLeadingZeros()))
return SE.getAddExpr(LHS, getSCEV(U->getOperand(1)));
}
break;
case Instruction::Xor:
// If the RHS of the xor is a signbit, then this is just an add.
// Instcombine turns add of signbit into xor as a strength reduction step.
if (ConstantInt *CI = dyn_cast<ConstantInt>(U->getOperand(1))) {
if (CI->getValue().isSignBit())
return SE.getAddExpr(getSCEV(U->getOperand(0)),
getSCEV(U->getOperand(1)));
else if (CI->isAllOnesValue())
return SE.getNotSCEV(getSCEV(U->getOperand(0)));
}
break;
case Instruction::Shl:
// Turn shift left of a constant amount into a multiply.
if (ConstantInt *SA = dyn_cast<ConstantInt>(U->getOperand(1))) {
uint32_t BitWidth = cast<IntegerType>(V->getType())->getBitWidth();
Constant *X = ConstantInt::get(
APInt(BitWidth, 1).shl(SA->getLimitedValue(BitWidth)));
return SE.getMulExpr(getSCEV(U->getOperand(0)), getSCEV(X));
}
break;
case Instruction::Trunc:
return SE.getTruncateExpr(getSCEV(U->getOperand(0)), U->getType());
case Instruction::ZExt:
return SE.getZeroExtendExpr(getSCEV(U->getOperand(0)), U->getType());
case Instruction::SExt:
return SE.getSignExtendExpr(getSCEV(U->getOperand(0)), U->getType());
case Instruction::BitCast:
// BitCasts are no-op casts so we just eliminate the cast.
if (U->getType()->isInteger() &&
U->getOperand(0)->getType()->isInteger())
return getSCEV(U->getOperand(0));
break;
case Instruction::PHI:
return createNodeForPHI(cast<PHINode>(U));
case Instruction::Select:
// This could be a smax or umax that was lowered earlier.
// Try to recover it.
if (ICmpInst *ICI = dyn_cast<ICmpInst>(U->getOperand(0))) {
Value *LHS = ICI->getOperand(0);
Value *RHS = ICI->getOperand(1);
switch (ICI->getPredicate()) {
case ICmpInst::ICMP_SLT:
case ICmpInst::ICMP_SLE:
std::swap(LHS, RHS);
// fall through
case ICmpInst::ICMP_SGT:
case ICmpInst::ICMP_SGE:
if (LHS == U->getOperand(1) && RHS == U->getOperand(2))
return SE.getSMaxExpr(getSCEV(LHS), getSCEV(RHS));
else if (LHS == U->getOperand(2) && RHS == U->getOperand(1))
// -smax(-x, -y) == smin(x, y).
return SE.getNegativeSCEV(SE.getSMaxExpr(
SE.getNegativeSCEV(getSCEV(LHS)),
SE.getNegativeSCEV(getSCEV(RHS))));
break;
case ICmpInst::ICMP_ULT:
case ICmpInst::ICMP_ULE:
std::swap(LHS, RHS);
// fall through
case ICmpInst::ICMP_UGT:
case ICmpInst::ICMP_UGE:
if (LHS == U->getOperand(1) && RHS == U->getOperand(2))
return SE.getUMaxExpr(getSCEV(LHS), getSCEV(RHS));
else if (LHS == U->getOperand(2) && RHS == U->getOperand(1))
// ~umax(~x, ~y) == umin(x, y)
return SE.getNotSCEV(SE.getUMaxExpr(SE.getNotSCEV(getSCEV(LHS)),
SE.getNotSCEV(getSCEV(RHS))));
break;
default:
break;
}
}
default: // We cannot analyze this expression.
break;
}
return SE.getUnknown(V);