diff --git a/lib/Transforms/Scalar/SeparateConstOffsetFromGEP.cpp b/lib/Transforms/Scalar/SeparateConstOffsetFromGEP.cpp index ac3e7c4d74a..b8529e174ca 100644 --- a/lib/Transforms/Scalar/SeparateConstOffsetFromGEP.cpp +++ b/lib/Transforms/Scalar/SeparateConstOffsetFromGEP.cpp @@ -165,6 +165,10 @@ class ConstantOffsetExtractor { void ComputeKnownBits(Value *V, APInt &KnownOne, APInt &KnownZero) const; /// Finds the first use of Used in U. Returns -1 if not found. static unsigned FindFirstUse(User *U, Value *Used); + /// Returns whether OPC (sext or zext) can be distributed to the operands of + /// BO. e.g., sext can be distributed to the operands of an "add nsw" because + /// sext (add nsw a, b) == add nsw (sext a), (sext b). + static bool Distributable(unsigned OPC, BinaryOperator *BO); /// The path from the constant offset to the old GEP index. e.g., if the GEP /// index is "a * b + (c + 5)". After running function find, UserChain[0] will @@ -223,6 +227,25 @@ FunctionPass *llvm::createSeparateConstOffsetFromGEPPass() { return new SeparateConstOffsetFromGEP(); } +bool ConstantOffsetExtractor::Distributable(unsigned OPC, BinaryOperator *BO) { + assert(OPC == Instruction::SExt || OPC == Instruction::ZExt); + + // sext (add/sub nsw A, B) == add/sub nsw (sext A), (sext B) + // zext (add/sub nuw A, B) == add/sub nuw (zext A), (zext B) + if (BO->getOpcode() == Instruction::Add || + BO->getOpcode() == Instruction::Sub) { + return (OPC == Instruction::SExt && BO->hasNoSignedWrap()) || + (OPC == Instruction::ZExt && BO->hasNoUnsignedWrap()); + } + + // sext/zext (and/or/xor A, B) == and/or/xor (sext/zext A), (sext/zext B) + // -instcombine also leverages this invariant to do the reverse + // transformation to reduce integer casts. + return BO->getOpcode() == Instruction::And || + BO->getOpcode() == Instruction::Or || + BO->getOpcode() == Instruction::Xor; +} + int64_t ConstantOffsetExtractor::findInEitherOperand(User *U, bool IsSub) { assert(U->getNumOperands() == 2); int64_t ConstantOffset = find(U->getOperand(0)); @@ -273,21 +296,14 @@ int64_t ConstantOffsetExtractor::find(Value *V) { ConstantOffset = findInEitherOperand(U, false); break; } - case Instruction::SExt: { - // For safety, we trace into sext only when its operand is marked - // "nsw" because xxx.nsw guarantees no signed wrap. e.g., we can safely - // transform "sext (add nsw a, 5)" into "add nsw (sext a), 5". - if (BinaryOperator *BO = dyn_cast(U->getOperand(0))) { - if (BO->hasNoSignedWrap()) - ConstantOffset = find(U->getOperand(0)); - } - break; - } + case Instruction::SExt: case Instruction::ZExt: { - // Similarly, we trace into zext only when its operand is marked with - // "nuw" because zext (add nuw a, b) == add nuw (zext a), (zext b). + // We trace into sext/zext if the operator can be distributed to its + // operand. e.g., we can transform into "sext (add nsw a, 5)" and + // extract constant 5, because + // sext (add nsw a, 5) == add nsw (sext a), 5 if (BinaryOperator *BO = dyn_cast(U->getOperand(0))) { - if (BO->hasNoUnsignedWrap()) + if (Distributable(O->getOpcode(), BO)) ConstantOffset = find(U->getOperand(0)); } break; diff --git a/test/Transforms/SeparateConstOffsetFromGEP/NVPTX/split-gep.ll b/test/Transforms/SeparateConstOffsetFromGEP/NVPTX/split-gep.ll index 320af5fd613..42136d2b657 100644 --- a/test/Transforms/SeparateConstOffsetFromGEP/NVPTX/split-gep.ll +++ b/test/Transforms/SeparateConstOffsetFromGEP/NVPTX/split-gep.ll @@ -57,6 +57,25 @@ define float* @ext_add_no_overflow(i64 %a, i32 %b, i64 %c, i32 %d) { ; CHECK: [[BASE_PTR:%[0-9]+]] = getelementptr [32 x [32 x float]]* @float_2d_array, i64 0, i64 %{{[0-9]+}}, i64 %{{[0-9]+}} ; CHECK: getelementptr float* [[BASE_PTR]], i64 33 +; Similar to @ext_add_no_overflow, we should be able to trace into sext/zext if +; its operand is an "or" instruction. +define float* @ext_or(i64 %a, i32 %b) { +entry: + %b1 = shl i32 %b, 2 + %b2 = or i32 %b1, 1 + %b3 = or i32 %b1, 2 + %b2.ext = sext i32 %b2 to i64 + %b3.ext = sext i32 %b3 to i64 + %i = add i64 %a, %b2.ext + %j = add i64 %a, %b3.ext + %p = getelementptr inbounds [32 x [32 x float]]* @float_2d_array, i64 0, i64 %i, i64 %j + ret float* %p +} +; CHECK-LABEL: @ext_or +; CHECK: [[BASE_PTR:%[0-9]+]] = getelementptr [32 x [32 x float]]* @float_2d_array, i64 0, i64 %{{[0-9]+}}, i64 %{{[0-9]+}} +; CHECK: [[BASE_INT:%[0-9]+]] = ptrtoint float* [[BASE_PTR]] to i64 +; CHECK: add i64 [[BASE_INT]], 136 + ; We should treat "or" with no common bits (%k) as "add", and leave "or" with ; potentially common bits (%l) as is. define float* @or(i64 %i) {