From 5c34e08b9fff9d4df2421e4f41ff15b85f638dd1 Mon Sep 17 00:00:00 2001 From: Stephen Lin Date: Sat, 20 Apr 2013 04:27:51 +0000 Subject: [PATCH] Allow tail call opportunity detection through nested and/or multiple iterations of extractelement/insertelement indirection git-svn-id: https://llvm.org/svn/llvm-project/llvm/trunk@179924 91177308-0d34-0410-b5e6-96231b3b80d8 --- lib/CodeGen/Analysis.cpp | 211 ++++++++++++++++++++------------ test/CodeGen/X86/tailcall-64.ll | 88 +++++++++++++ 2 files changed, 220 insertions(+), 79 deletions(-) diff --git a/lib/CodeGen/Analysis.cpp b/lib/CodeGen/Analysis.cpp index dd7282c0ad9..9723f8080c8 100644 --- a/lib/CodeGen/Analysis.cpp +++ b/lib/CodeGen/Analysis.cpp @@ -201,62 +201,135 @@ ISD::CondCode llvm::getICmpCondCode(ICmpInst::Predicate Pred) { } } - -/// getNoopInput - If V is a noop (i.e., lowers to no machine code), look -/// through it (and any transitive noop operands to it) and return its input -/// value. This is used to determine if a tail call can be formed. -/// -static const Value *getNoopInput(const Value *V, const TargetLowering &TLI) { - // If V is not an instruction, it can't be looked through. - const Instruction *I = dyn_cast(V); - if (I == 0 || !I->hasOneUse() || I->getNumOperands() == 0) return V; - - Value *Op = I->getOperand(0); - - // Look through truly no-op truncates. - if (isa(I) && - TLI.isTruncateFree(I->getOperand(0)->getType(), I->getType())) - return getNoopInput(I->getOperand(0), TLI); - - // Look through truly no-op bitcasts. - if (isa(I)) { - // No type change at all. - if (Op->getType() == I->getType()) - return getNoopInput(Op, TLI); - - // Pointer to pointer cast. - if (Op->getType()->isPointerTy() && I->getType()->isPointerTy()) - return getNoopInput(Op, TLI); - - if (isa(Op->getType()) && isa(I->getType()) && - TLI.isTypeLegal(EVT::getEVT(Op->getType())) && - TLI.isTypeLegal(EVT::getEVT(I->getType()))) - return getNoopInput(Op, TLI); - } - - // Look through inttoptr. - if (isa(I) && !isa(I->getType())) { - // Make sure this isn't a truncating or extending cast. We could support - // this eventually, but don't bother for now. - if (TLI.getPointerTy().getSizeInBits() == - cast(Op->getType())->getBitWidth()) - return getNoopInput(Op, TLI); - } - - // Look through ptrtoint. - if (isa(I) && !isa(I->getType())) { - // Make sure this isn't a truncating or extending cast. We could support - // this eventually, but don't bother for now. - if (TLI.getPointerTy().getSizeInBits() == - cast(I->getType())->getBitWidth()) - return getNoopInput(Op, TLI); - } - - - // Otherwise it's not something we can look through. - return V; +static bool isNoopBitcast(Type *T1, Type *T2, + const TargetLowering& TLI) { + return T1 == T2 || (T1->isPointerTy() && T2->isPointerTy()) || + (isa(T1) && isa(T2) && + TLI.isTypeLegal(EVT::getEVT(T1)) && TLI.isTypeLegal(EVT::getEVT(T2))); } +/// sameNoopInput - Return true if V1 == V2, else if either V1 or V2 is a noop +/// (i.e., lowers to no machine code), look through it (and any transitive noop +/// operands to it) and check if it has the same noop input value. This is +/// used to determine if a tail call can be formed. +static bool sameNoopInput(const Value *V1, const Value *V2, + SmallVectorImpl &Els1, + SmallVectorImpl &Els2, + const TargetLowering &TLI) { + using std::swap; + bool swapParity = false; + bool equalEls = Els1 == Els2; + while (true) { + if ((equalEls && V1 == V2) || isa(V1) || isa(V2)) { + if (swapParity) + // Revert to original Els1 and Els2 to avoid confusing recursive calls + swap(Els1, Els2); + return true; + } + + // Try to look through V1; if V1 is not an instruction, it can't be looked + // through. + const Instruction *I = dyn_cast(V1); + const Value *NoopInput = 0; + if (I != 0 && I->getNumOperands() > 0) { + Value *Op = I->getOperand(0); + if (isa(I)) { + // Look through truly no-op truncates. + if (TLI.isTruncateFree(Op->getType(), I->getType())) + NoopInput = Op; + } else if (isa(I)) { + // Look through truly no-op bitcasts. + if (isNoopBitcast(Op->getType(), I->getType(), TLI)) + NoopInput = Op; + } else if (isa(I)) { + // Look through getelementptr + if (cast(I)->hasAllZeroIndices()) + NoopInput = Op; + } else if (isa(I)) { + // Look through inttoptr. + // Make sure this isn't a truncating or extending cast. We could + // support this eventually, but don't bother for now. + if (!isa(I->getType()) && + TLI.getPointerTy().getSizeInBits() == + cast(Op->getType())->getBitWidth()) + NoopInput = Op; + } else if (isa(I)) { + // Look through ptrtoint. + // Make sure this isn't a truncating or extending cast. We could + // support this eventually, but don't bother for now. + if (!isa(I->getType()) && + TLI.getPointerTy().getSizeInBits() == + cast(I->getType())->getBitWidth()) + NoopInput = Op; + } + } + + if (NoopInput) { + V1 = NoopInput; + continue; + } + + // If we already swapped, avoid infinite loop + if (swapParity) + break; + + // Otherwise, swap V1<->V2, Els1<->Els2 + swap(V1, V2); + swap(Els1, Els2); + swapParity = !swapParity; + } + + for (unsigned n = 0; n < 2; ++n) { + if (isa(V1)) { + if (isa(V1->getType())) { + // Look through insertvalue + unsigned i, e; + for (i = 0, e = cast(V1->getType())->getNumElements(); + i != e; ++i) { + const Value *InScalar = FindInsertedValue(const_cast(V1), i); + if (InScalar == 0) + break; + Els1.push_back(i); + if (!sameNoopInput(InScalar, V2, Els1, Els2, TLI)) { + Els1.pop_back(); + break; + } + Els1.pop_back(); + } + if (i == e) { + if (swapParity) + swap(Els1, Els2); + return true; + } + } + } else if (!Els1.empty() && isa(V1)) { + const ExtractValueInst *EVI = cast(V1); + unsigned i = Els1.back(); + // If the scalar value being inserted is an extractvalue of the right + // index from the call, then everything is good. + if (isa(EVI->getOperand(0)->getType()) && + EVI->getNumIndices() == 1 && EVI->getIndices()[0] == i) { + // Look through extractvalue + Els1.pop_back(); + if (sameNoopInput(EVI->getOperand(0), V2, Els1, Els2, TLI)) { + Els1.push_back(i); + if (swapParity) + swap(Els1, Els2); + return true; + } + Els1.push_back(i); + } + } + + swap(V1, V2); + swap(Els1, Els2); + swapParity = !swapParity; + } + + if (swapParity) + swap(Els1, Els2); + return false; +} /// Test if the given instruction is in a position to be optimized /// with a tail-call. This roughly means that it's in a block with @@ -264,7 +337,8 @@ static const Value *getNoopInput(const Value *V, const TargetLowering &TLI) { /// between it and the return. /// /// This function only tests target-independent requirements. -bool llvm::isInTailCallPosition(ImmutableCallSite CS,const TargetLowering &TLI){ +bool llvm::isInTailCallPosition(ImmutableCallSite CS, + const TargetLowering &TLI) { const Instruction *I = CS.getInstruction(); const BasicBlock *ExitBB = I->getParent(); const TerminatorInst *Term = ExitBB->getTerminator(); @@ -322,28 +396,7 @@ bool llvm::isInTailCallPosition(ImmutableCallSite CS,const TargetLowering &TLI){ CallerAttrs.hasAttribute(AttributeSet::ReturnIndex, Attribute::SExt)) return false; - // Otherwise, make sure the unmodified return value of I is the return value. - // We handle two cases: multiple return values + scalars. - Value *RetVal = Ret->getOperand(0); - if (!isa(RetVal) || !isa(RetVal->getType())) - // Handle scalars first. - return getNoopInput(Ret->getOperand(0), TLI) == I; - - // If this is an aggregate return, look through the insert/extract values and - // see if each is transparent. - for (unsigned i = 0, e =cast(RetVal->getType())->getNumElements(); - i != e; ++i) { - const Value *InScalar = FindInsertedValue(RetVal, i); - if (InScalar == 0) return false; - InScalar = getNoopInput(InScalar, TLI); - - // If the scalar value being inserted is an extractvalue of the right index - // from the call, then everything is good. - const ExtractValueInst *EVI = dyn_cast(InScalar); - if (EVI == 0 || EVI->getOperand(0) != I || EVI->getNumIndices() != 1 || - EVI->getIndices()[0] != i) - return false; - } - - return true; + // Otherwise, make sure the return value and I have the same value + SmallVector Els1, Els2; + return sameNoopInput(Ret->getOperand(0), I, Els1, Els2, TLI); } diff --git a/test/CodeGen/X86/tailcall-64.ll b/test/CodeGen/X86/tailcall-64.ll index ecc253ba587..eb2fef01155 100644 --- a/test/CodeGen/X86/tailcall-64.ll +++ b/test/CodeGen/X86/tailcall-64.ll @@ -50,7 +50,16 @@ define {i64, i64} @test_pair_trivial() { ; CHECK: test_pair_trivial: ; CHECK: jmp _testp ## TAILCALL +define {i64, i64} @test_pair_notail() { + %A = tail call i64 @testi() + %b = insertvalue {i64, i64} undef, i64 %A, 0 + %c = insertvalue {i64, i64} %b, i64 %A, 1 + + ret { i64, i64} %c +} +; CHECK: test_pair_notail: +; CHECK-NOT: jmp _testi define {i64, i64} @test_pair_trivial_extract() { %A = tail call { i64, i64} @testp() @@ -66,6 +75,20 @@ define {i64, i64} @test_pair_trivial_extract() { ; CHECK: test_pair_trivial_extract: ; CHECK: jmp _testp ## TAILCALL +define {i64, i64} @test_pair_notail_extract() { + %A = tail call { i64, i64} @testp() + %x = extractvalue { i64, i64} %A, 0 + %y = extractvalue { i64, i64} %A, 1 + + %b = insertvalue {i64, i64} undef, i64 %y, 0 + %c = insertvalue {i64, i64} %b, i64 %x, 1 + + ret { i64, i64} %c +} + +; CHECK: test_pair_notail_extract: +; CHECK-NOT: jmp _testp + define {i8*, i64} @test_pair_conv_extract() { %A = tail call { i64, i64} @testp() %x = extractvalue { i64, i64} %A, 0 @@ -82,7 +105,72 @@ define {i8*, i64} @test_pair_conv_extract() { ; CHECK: test_pair_conv_extract: ; CHECK: jmp _testp ## TAILCALL +define {i64, i64} @test_pair_multiple_extract() { + %A = tail call { i64, i64} @testp() + %x = extractvalue { i64, i64} %A, 0 + %y = extractvalue { i64, i64} %A, 1 + + %b = insertvalue {i64, i64} undef, i64 %x, 0 + %c = insertvalue {i64, i64} %b, i64 %y, 1 + %x1 = extractvalue { i64, i64} %b, 0 + %y1 = extractvalue { i64, i64} %c, 1 + + %d = insertvalue {i64, i64} undef, i64 %x1, 0 + %e = insertvalue {i64, i64} %b, i64 %y1, 1 + + ret { i64, i64} %e +} + +; CHECK: test_pair_multiple_extract: +; CHECK: jmp _testp ## TAILCALL + +define {i64, i64} @test_pair_undef_extract() { + %A = tail call { i64, i64} @testp() + %x = extractvalue { i64, i64} %A, 0 + + %b = insertvalue {i64, i64} undef, i64 %x, 0 + + ret { i64, i64} %b +} + +; CHECK: test_pair_undef_extract: +; CHECK: jmp _testp ## TAILCALL + +declare { i64, { i32, i32 } } @testn() + +define {i64, {i32, i32}} @test_nest() { + %A = tail call { i64, { i32, i32 } } @testn() + %x = extractvalue { i64, { i32, i32}} %A, 0 + %y = extractvalue { i64, { i32, i32}} %A, 1 + %y1 = extractvalue { i32, i32} %y, 0 + %y2 = extractvalue { i32, i32} %y, 1 + + %b = insertvalue {i64, {i32, i32}} undef, i64 %x, 0 + %c1 = insertvalue {i32, i32} undef, i32 %y1, 0 + %c2 = insertvalue {i32, i32} %c1, i32 %y2, 1 + %c = insertvalue {i64, {i32, i32}} %b, {i32, i32} %c2, 1 + + ret { i64, { i32, i32}} %c +} + +; CHECK: test_nest: +; CHECK: jmp _testn ## TAILCALL + +%struct.A = type { i32 } +%struct.B = type { %struct.A, i32 } + +declare %struct.B* @testu() + +define %struct.A* @test_upcast() { +entry: + %A = tail call %struct.B* @testu() + %x = getelementptr inbounds %struct.B* %A, i32 0, i32 0 + ret %struct.A* %x +} + +; CHECK: test_upcast: +; CHECK: jmp _testu ## TAILCALL ; PR13006 define { i64, i64 } @crash(i8* %this) {