Refactor the ptr runtime check generation code. No functionality change.

git-svn-id: https://llvm.org/svn/llvm-project/llvm/trunk@168568 91177308-0d34-0410-b5e6-96231b3b80d8
This commit is contained in:
Nadav Rotem 2012-11-25 16:27:16 +00:00
parent 170a15e98d
commit 8c6b73666b

View File

@ -128,6 +128,10 @@ public:
}
private:
/// Add code that checks at runtime if the accessed arrays overlap.
/// Returns the comperator value or NULL if no check is needed.
Value* addRuntimeCheck(LoopVectorizationLegality *Legal,
Instruction *Loc);
/// Create an empty loop, based on the loop ranges of the old loop.
void createEmptyLoop(LoopVectorizationLegality *Legal);
/// Copy and widen the instructions from the old loop.
@ -671,6 +675,67 @@ void SingleBlockLoopVectorizer::scalarizeInstruction(Instruction *Instr) {
WidenMap[Instr] = VecResults;
}
Value*
SingleBlockLoopVectorizer::addRuntimeCheck(LoopVectorizationLegality *Legal,
Instruction *Loc) {
LoopVectorizationLegality::RuntimePointerCheck *PtrRtCheck =
Legal->getRuntimePointerCheck();
if (!PtrRtCheck->Need)
return NULL;
Value *MemoryRuntimeCheck = 0;
unsigned NumPointers = PtrRtCheck->Pointers.size();
SmallVector<Value* , 2> Starts;
SmallVector<Value* , 2> Ends;
SCEVExpander Exp(*SE, "induction");
// Use this type for pointer arithmetic.
Type* PtrArithTy = PtrRtCheck->Pointers[0]->getType();
for (unsigned i=0; i < NumPointers; ++i) {
Value *Ptr = PtrRtCheck->Pointers[i];
const SCEV *Sc = SE->getSCEV(Ptr);
if (SE->isLoopInvariant(Sc, OrigLoop)) {
DEBUG(dbgs() << "LV1: Adding RT check for a loop invariant ptr:" <<
*Ptr <<"\n");
Starts.push_back(Ptr);
Ends.push_back(Ptr);
} else {
DEBUG(dbgs() << "LV: Adding RT check for range:" << *Ptr <<"\n");
Value *Start = Exp.expandCodeFor(PtrRtCheck->Starts[i],
PtrArithTy, Loc);
Value *End = Exp.expandCodeFor(PtrRtCheck->Ends[i], PtrArithTy, Loc);
Starts.push_back(Start);
Ends.push_back(End);
}
}
for (unsigned i = 0; i < NumPointers; ++i) {
for (unsigned j = i+1; j < NumPointers; ++j) {
Value *Cmp0 = CmpInst::Create(Instruction::ICmp, CmpInst::ICMP_ULE,
Starts[i], Ends[j], "bound0", Loc);
Value *Cmp1 = CmpInst::Create(Instruction::ICmp, CmpInst::ICMP_ULE,
Starts[j], Ends[i], "bound1", Loc);
Value *IsConflict = BinaryOperator::Create(Instruction::And, Cmp0, Cmp1,
"found.conflict", Loc);
if (MemoryRuntimeCheck) {
MemoryRuntimeCheck = BinaryOperator::Create(Instruction::Or,
MemoryRuntimeCheck,
IsConflict,
"conflict.rdx", Loc);
} else {
MemoryRuntimeCheck = IsConflict;
}
}
}
return MemoryRuntimeCheck;
}
void
SingleBlockLoopVectorizer::createEmptyLoop(LoopVectorizationLegality *Legal) {
/*
@ -791,56 +856,7 @@ SingleBlockLoopVectorizer::createEmptyLoop(LoopVectorizationLegality *Legal) {
StartIdx,
"cmp.zero", Loc);
LoopVectorizationLegality::RuntimePointerCheck *PtrRtCheck =
Legal->getRuntimePointerCheck();
Value *MemoryRuntimeCheck = 0;
if (PtrRtCheck->Need) {
unsigned NumPointers = PtrRtCheck->Pointers.size();
SmallVector<Value* , 2> Starts;
SmallVector<Value* , 2> Ends;
// Use this type for pointer arithmetic.
Type* PtrArithTy = PtrRtCheck->Pointers[0]->getType();
for (unsigned i=0; i < NumPointers; ++i) {
Value *Ptr = PtrRtCheck->Pointers[i];
const SCEV *Sc = SE->getSCEV(Ptr);
if (SE->isLoopInvariant(Sc, OrigLoop)) {
DEBUG(dbgs() << "LV1: Adding RT check for a loop invariant ptr:" <<
*Ptr <<"\n");
Starts.push_back(Ptr);
Ends.push_back(Ptr);
} else {
DEBUG(dbgs() << "LV: Adding RT check for range:" << *Ptr <<"\n");
Value *Start = Exp.expandCodeFor(PtrRtCheck->Starts[i],
PtrArithTy, Loc);
Value *End = Exp.expandCodeFor(PtrRtCheck->Ends[i], PtrArithTy, Loc);
Starts.push_back(Start);
Ends.push_back(End);
}
}
for (unsigned i = 0; i < NumPointers; ++i) {
for (unsigned j = i+1; j < NumPointers; ++j) {
Value *Cmp0 = CmpInst::Create(Instruction::ICmp, CmpInst::ICMP_ULE,
Starts[i], Ends[j], "bound0", Loc);
Value *Cmp1 = CmpInst::Create(Instruction::ICmp, CmpInst::ICMP_ULE,
Starts[j], Ends[i], "bound1", Loc);
Value *IsConflict = BinaryOperator::Create(Instruction::And, Cmp0, Cmp1,
"found.conflict", Loc);
if (MemoryRuntimeCheck) {
MemoryRuntimeCheck = BinaryOperator::Create(Instruction::Or,
MemoryRuntimeCheck,
IsConflict,
"conflict.rdx", Loc);
} else {
MemoryRuntimeCheck = IsConflict;
}
}
}
}// end of need-runtime-check code.
Value *MemoryRuntimeCheck = addRuntimeCheck(Legal, Loc);
// If we are using memory runtime checks, include them in.
if (MemoryRuntimeCheck) {