From 8c34cd287a07d80491cde1e86da98568675dbe47 Mon Sep 17 00:00:00 2001 From: Chris Lattner Date: Sun, 5 Oct 2008 02:13:19 +0000 Subject: [PATCH] rewrite bswap matching to be more general, allowing arbitrary shifting and masking inside a bswap expr. This allows it to handle the cases from PR2842, which involve the intermediate 'or' expressions being shifted, not just the input value. git-svn-id: https://llvm.org/svn/llvm-project/llvm/trunk@57095 91177308-0d34-0410-b5e6-96231b3b80d8 --- .../Scalar/InstructionCombining.cpp | 198 +++++++++++------- test/Transforms/InstCombine/bswap.ll | 17 +- 2 files changed, 137 insertions(+), 78 deletions(-) diff --git a/lib/Transforms/Scalar/InstructionCombining.cpp b/lib/Transforms/Scalar/InstructionCombining.cpp index 95ed49d5262..6f9893cc8fa 100644 --- a/lib/Transforms/Scalar/InstructionCombining.cpp +++ b/lib/Transforms/Scalar/InstructionCombining.cpp @@ -3887,88 +3887,130 @@ Instruction *InstCombiner::visitAnd(BinaryOperator &I) { return Changed ? &I : 0; } -/// CollectBSwapParts - Look to see if the specified value defines a single byte -/// in the result. If it does, and if the specified byte hasn't been filled in -/// yet, fill it in and return false. -static bool CollectBSwapParts(Value *V, SmallVector &ByteValues) { - Instruction *I = dyn_cast(V); - if (I == 0) return true; - - // If this is an or instruction, it is an inner node of the bswap. - if (I->getOpcode() == Instruction::Or) - return CollectBSwapParts(I->getOperand(0), ByteValues) || - CollectBSwapParts(I->getOperand(1), ByteValues); - - uint32_t BitWidth = I->getType()->getPrimitiveSizeInBits(); - // If this is a shift by a constant int, and it is "24", then its operand - // defines a byte. We only handle unsigned types here. - if (I->isShift() && isa(I->getOperand(1))) { - // Not shifting the entire input by N-1 bytes? - if (cast(I->getOperand(1))->getLimitedValue(BitWidth) != - 8*(ByteValues.size()-1)) - return true; - - unsigned DestNo; - if (I->getOpcode() == Instruction::Shl) { - // X << 24 defines the top byte with the lowest of the input bytes. - DestNo = ByteValues.size()-1; - } else if (I->getOpcode() == Instruction::LShr) { - // X >>u 24 defines the low byte with the highest of the input bytes. - DestNo = 0; - } else { - // Arithmetic shift right may have the top bits set. - return true; +/// CollectBSwapParts - Analyze the specified subexpression and see if it is +/// capable of providing pieces of a bswap. The subexpression provides pieces +/// of a bswap if it is proven that each of the non-zero bytes in the output of +/// the expression came from the corresponding "byte swapped" byte in some other +/// value. For example, if the current subexpression is "(shl i32 %X, 24)" then +/// we know that the expression deposits the low byte of %X into the high byte +/// of the bswap result and that all other bytes are zero. This expression is +/// accepted, the high byte of ByteValues is set to X to indicate a correct +/// match. +/// +/// This function returns true if the match was unsuccessful and false if so. +/// On entry to the function the "OverallLeftShift" is a signed integer value +/// indicating the number of bytes that the subexpression is later shifted. For +/// example, if the expression is later right shifted by 16 bits, the +/// OverallLeftShift value would be -2 on entry. This is used to specify which +/// byte of ByteValues is actually being set. +/// +/// Similarly, ByteMask is a bitmask where a bit is clear if its corresponding +/// byte is masked to zero by a user. For example, in (X & 255), X will be +/// processed with a bytemask of 1. Because bytemask is 32-bits, this limits +/// this function to working on up to 32-byte (256 bit) values. ByteMask is +/// always in the local (OverallLeftShift) coordinate space. +/// +static bool CollectBSwapParts(Value *V, int OverallLeftShift, uint32_t ByteMask, + SmallVector &ByteValues) { + if (Instruction *I = dyn_cast(V)) { + // If this is an or instruction, it may be an inner node of the bswap. + if (I->getOpcode() == Instruction::Or) { + return CollectBSwapParts(I->getOperand(0), OverallLeftShift, ByteMask, + ByteValues) || + CollectBSwapParts(I->getOperand(1), OverallLeftShift, ByteMask, + ByteValues); + } + + // If this is a logical shift by a constant multiple of 8, recurse with + // OverallLeftShift and ByteMask adjusted. + if (I->isLogicalShift() && isa(I->getOperand(1))) { + unsigned ShAmt = + cast(I->getOperand(1))->getLimitedValue(~0U); + // Ensure the shift amount is defined and of a byte value. + if ((ShAmt & 7) || (ShAmt > 8*ByteValues.size())) + return true; + + unsigned ByteShift = ShAmt >> 3; + if (I->getOpcode() == Instruction::Shl) { + // X << 2 -> collect(X, +2) + OverallLeftShift += ByteShift; + ByteMask >>= ByteShift; + } else { + // X >>u 2 -> collect(X, -2) + OverallLeftShift -= ByteShift; + ByteMask <<= ByteShift; + ByteMask &= (~0U >> 32-ByteValues.size()); + } + + if (OverallLeftShift >= (int)ByteValues.size()) return true; + if (OverallLeftShift <= -(int)ByteValues.size()) return true; + + return CollectBSwapParts(I->getOperand(0), OverallLeftShift, ByteMask, + ByteValues); + } + + // If this is a logical 'and' with a mask that clears bytes, clear the + // corresponding bytes in ByteMask. + if (I->getOpcode() == Instruction::And && + isa(I->getOperand(1))) { + // Scan every byte of the and mask, seeing if the byte is either 0 or 255. + unsigned NumBytes = ByteValues.size(); + APInt Byte(I->getType()->getPrimitiveSizeInBits(), 255); + const APInt &AndMask = cast(I->getOperand(1))->getValue(); + + for (unsigned i = 0; i != NumBytes; ++i, Byte <<= 8) { + // If this byte is masked out by a later operation, we don't care what + // the and mask is. + if ((ByteMask & (1 << i)) == 0) + continue; + + // If the AndMask is all zeros for this byte, clear the bit. + APInt MaskB = AndMask & Byte; + if (MaskB == 0) { + ByteMask &= ~(1U << i); + continue; + } + + // If the AndMask is not all ones for this byte, it's not a bytezap. + if (MaskB != Byte) + return true; + + // Otherwise, this byte is kept. + } + + return CollectBSwapParts(I->getOperand(0), OverallLeftShift, ByteMask, + ByteValues); } - - // If the destination byte value is already defined, the values are or'd - // together, which isn't a bswap (unless it's an or of the same bits). - if (ByteValues[DestNo] && ByteValues[DestNo] != I->getOperand(0)) - return true; - ByteValues[DestNo] = I->getOperand(0); - return false; } - // Otherwise, we can only handle and(shift X, imm), imm). Bail out of if we - // don't have this. - Value *Shift = 0, *ShiftLHS = 0; - ConstantInt *AndAmt = 0, *ShiftAmt = 0; - if (!match(I, m_And(m_Value(Shift), m_ConstantInt(AndAmt))) || - !match(Shift, m_Shift(m_Value(ShiftLHS), m_ConstantInt(ShiftAmt)))) - return true; - Instruction *SI = cast(Shift); - - // Make sure that the shift amount is by a multiple of 8 and isn't too big. - if (ShiftAmt->getLimitedValue(BitWidth) & 7 || - ShiftAmt->getLimitedValue(BitWidth) > 8*ByteValues.size()) - return true; + // Okay, we got to something that isn't a shift, 'or' or 'and'. This must be + // the input value to the bswap. Some observations: 1) if more than one byte + // is demanded from this input, then it could not be successfully assembled + // into a byteswap. At least one of the two bytes would not be aligned with + // their ultimate destination. + if (!isPowerOf2_32(ByteMask)) return true; + unsigned InputByteNo = CountTrailingZeros_32(ByteMask); - // Turn 0xFF -> 0, 0xFF00 -> 1, 0xFF0000 -> 2, etc. - unsigned DestByte; - if (AndAmt->getValue().getActiveBits() > 64) - return true; - uint64_t AndAmtVal = AndAmt->getZExtValue(); - for (DestByte = 0; DestByte != ByteValues.size(); ++DestByte) - if (AndAmtVal == uint64_t(0xFF) << 8*DestByte) - break; - // Unknown mask for bswap. - if (DestByte == ByteValues.size()) return true; - - unsigned ShiftBytes = ShiftAmt->getZExtValue()/8; - unsigned SrcByte; - if (SI->getOpcode() == Instruction::Shl) - SrcByte = DestByte - ShiftBytes; - else - SrcByte = DestByte + ShiftBytes; - - // If the SrcByte isn't a bswapped value from the DestByte, reject it. - if (SrcByte != ByteValues.size()-DestByte-1) - return true; + // 2) The input and ultimate destinations must line up: if byte 3 of an i32 + // is demanded, it needs to go into byte 0 of the result. This means that the + // byte needs to be shifted until it lands in the right byte bucket. The + // shift amount depends on the position: if the byte is coming from the high + // part of the value (e.g. byte 3) then it must be shifted right. If from the + // low part, it must be shifted left. + unsigned DestByteNo = InputByteNo + OverallLeftShift; + if (InputByteNo < ByteValues.size()/2) { + if (ByteValues.size()-1-DestByteNo != InputByteNo) + return true; + } else { + if (ByteValues.size()-1-DestByteNo != InputByteNo) + return true; + } // If the destination byte value is already defined, the values are or'd // together, which isn't a bswap (unless it's an or of the same bits). - if (ByteValues[DestByte] && ByteValues[DestByte] != SI->getOperand(0)) + if (ByteValues[DestByteNo] && ByteValues[DestByteNo] != V) return true; - ByteValues[DestByte] = SI->getOperand(0); + ByteValues[DestByteNo] = V; return false; } @@ -3976,7 +4018,9 @@ static bool CollectBSwapParts(Value *V, SmallVector &ByteValues) { /// If so, insert the new bswap intrinsic and return it. Instruction *InstCombiner::MatchBSwap(BinaryOperator &I) { const IntegerType *ITy = dyn_cast(I.getType()); - if (!ITy || ITy->getBitWidth() % 16) + if (!ITy || ITy->getBitWidth() % 16 || + // ByteMask only allows up to 32-byte values. + ITy->getBitWidth() > 32*8) return 0; // Can only bswap pairs of bytes. Can't do vectors. /// ByteValues - For each byte of the result, we keep track of which value @@ -3985,8 +4029,8 @@ Instruction *InstCombiner::MatchBSwap(BinaryOperator &I) { ByteValues.resize(ITy->getBitWidth()/8); // Try to find all the pieces corresponding to the bswap. - if (CollectBSwapParts(I.getOperand(0), ByteValues) || - CollectBSwapParts(I.getOperand(1), ByteValues)) + uint32_t ByteMask = ~0U >> (32-ByteValues.size()); + if (CollectBSwapParts(&I, 0, ByteMask, ByteValues)) return 0; // Check to see if all of the bytes come from the same value. diff --git a/test/Transforms/InstCombine/bswap.ll b/test/Transforms/InstCombine/bswap.ll index 5db4e73a546..2ba718e5847 100644 --- a/test/Transforms/InstCombine/bswap.ll +++ b/test/Transforms/InstCombine/bswap.ll @@ -1,5 +1,5 @@ ; RUN: llvm-as < %s | opt -instcombine | llvm-dis | \ -; RUN: grep {call.*llvm.bswap} | count 5 +; RUN: grep {call.*llvm.bswap} | count 6 define i32 @test1(i32 %i) { %tmp1 = lshr i32 %i, 24 ; [#uses=1] @@ -55,3 +55,18 @@ define i16 @test5(i16 %a) { %retval = trunc i32 %tmp6.upgrd.4 to i16 ; [#uses=1] ret i16 %retval } + +; PR2842 +define i32 @test6(i32 %x) nounwind readnone { + %tmp = shl i32 %x, 16 ; [#uses=1] + %x.mask = and i32 %x, 65280 ; [#uses=1] + %tmp1 = lshr i32 %x, 16 ; [#uses=1] + %tmp2 = and i32 %tmp1, 255 ; [#uses=1] + %tmp3 = or i32 %x.mask, %tmp ; [#uses=1] + %tmp4 = or i32 %tmp3, %tmp2 ; [#uses=1] + %tmp5 = shl i32 %tmp4, 8 ; [#uses=1] + %tmp6 = lshr i32 %x, 24 ; [#uses=1] + %tmp7 = or i32 %tmp5, %tmp6 ; [#uses=1] + ret i32 %tmp7 +} +