Teach the X86 instruction selection to do some heroic transforms to

detect a pattern which can be implemented with a small 'shl' embedded in
the addressing mode scale. This happens in real code as follows:

  unsigned x = my_accelerator_table[input >> 11];

Here we have some lookup table that we look into using the high bits of
'input'. Each entity in the table is 4-bytes, which means this
implicitly gets turned into (once lowered out of a GEP):

  *(unsigned*)((char*)my_accelerator_table + ((input >> 11) << 2));

The shift right followed by a shift left is canonicalized to a smaller
shift right and masking off the low bits. That hides the shift right
which x86 has an addressing mode designed to support. We now detect
masks of this form, and produce the longer shift right followed by the
proper addressing mode. In addition to saving a (rather large)
instruction, this also reduces stalls in Intel chips on benchmarks I've
measured.

In order for all of this to work, one part of the DAG needs to be
canonicalized *still further* than it currently is. This involves
removing pointless 'trunc' nodes between a zextload and a zext. Without
that, we end up generating spurious masks and hiding the pattern.

git-svn-id: https://llvm.org/svn/llvm-project/llvm/trunk@147936 91177308-0d34-0410-b5e6-96231b3b80d8
This commit is contained in:
Chandler Carruth 2012-01-11 08:41:08 +00:00
parent 88c5c42c5c
commit f103b3d1b9
3 changed files with 213 additions and 0 deletions

View File

@ -4254,6 +4254,29 @@ SDValue DAGCombiner::visitZERO_EXTEND(SDNode *N) {
return DAG.getNode(ISD::ZERO_EXTEND, N->getDebugLoc(), VT,
N0.getOperand(0));
// fold (zext (truncate x)) -> (zext x) or
// (zext (truncate x)) -> (truncate x)
// This is valid when the truncated bits of x are already zero.
// FIXME: We should extend this to work for vectors too.
if (N0.getOpcode() == ISD::TRUNCATE && !VT.isVector()) {
SDValue Op = N0.getOperand(0);
APInt TruncatedBits
= APInt::getBitsSet(Op.getValueSizeInBits(),
N0.getValueSizeInBits(),
std::min(Op.getValueSizeInBits(),
VT.getSizeInBits()));
APInt KnownZero, KnownOne;
DAG.ComputeMaskedBits(Op, TruncatedBits, KnownZero, KnownOne);
if (TruncatedBits == KnownZero) {
if (VT.bitsGT(Op.getValueType()))
return DAG.getNode(ISD::ZERO_EXTEND, N->getDebugLoc(), VT, Op);
if (VT.bitsLT(Op.getValueType()))
return DAG.getNode(ISD::TRUNCATE, N->getDebugLoc(), VT, Op);
return Op;
}
}
// fold (zext (truncate (load x))) -> (zext (smaller load x))
// fold (zext (truncate (srl (load x), c))) -> (zext (small load (x+c/n)))
if (N0.getOpcode() == ISD::TRUNCATE) {

View File

@ -725,6 +725,140 @@ bool X86DAGToDAGISel::MatchAddress(SDValue N, X86ISelAddressMode &AM) {
return false;
}
// Implement some heroics to detect shifts of masked values where the mask can
// be replaced by extending the shift and undoing that in the addressing mode
// scale. Patterns such as (shl (srl x, c1), c2) are canonicalized into (and
// (srl x, SHIFT), MASK) by DAGCombines that don't know the shl can be done in
// the addressing mode. This results in code such as:
//
// int f(short *y, int *lookup_table) {
// ...
// return *y + lookup_table[*y >> 11];
// }
//
// Turning into:
// movzwl (%rdi), %eax
// movl %eax, %ecx
// shrl $11, %ecx
// addl (%rsi,%rcx,4), %eax
//
// Instead of:
// movzwl (%rdi), %eax
// movl %eax, %ecx
// shrl $9, %ecx
// andl $124, %rcx
// addl (%rsi,%rcx), %eax
//
static bool FoldMaskAndShiftToScale(SelectionDAG &DAG, SDValue N,
X86ISelAddressMode &AM) {
// Scale must not be used already.
if (AM.IndexReg.getNode() != 0 || AM.Scale != 1) return true;
SDValue Shift = N;
SDValue And = N.getOperand(0);
if (N.getOpcode() != ISD::SRL)
std::swap(Shift, And);
if (Shift.getOpcode() != ISD::SRL || And.getOpcode() != ISD::AND ||
!Shift.hasOneUse() ||
!isa<ConstantSDNode>(Shift.getOperand(1)) ||
!isa<ConstantSDNode>(And.getOperand(1)))
return true;
SDValue X = (N == Shift ? And.getOperand(0) : Shift.getOperand(0));
// We only handle up to 64-bit values here as those are what matter for
// addressing mode optimizations.
if (X.getValueSizeInBits() > 64) return true;
uint64_t Mask = And.getConstantOperandVal(1);
unsigned ShiftAmt = Shift.getConstantOperandVal(1);
unsigned MaskLZ = CountLeadingZeros_64(Mask);
unsigned MaskTZ = CountTrailingZeros_64(Mask);
// The amount of shift we're trying to fit into the addressing mode is taken
// from the trailing zeros of the mask. If the mask is pre-shift, we subtract
// the shift amount.
int AMShiftAmt = MaskTZ - (N == Shift ? ShiftAmt : 0);
// There is nothing we can do here unless the mask is removing some bits.
// Also, the addressing mode can only represent shifts of 1, 2, or 3 bits.
if (AMShiftAmt <= 0 || AMShiftAmt > 3) return true;
// We also need to ensure that mask is a continuous run of bits.
if (CountTrailingOnes_64(Mask >> MaskTZ) + MaskTZ + MaskLZ != 64) return true;
// Scale the leading zero count down based on the actual size of the value.
// Also scale it down based on the size of the shift if it was applied
// before the mask.
MaskLZ -= (64 - X.getValueSizeInBits()) + (N == Shift ? 0 : ShiftAmt);
// The final check is to ensure that any masked out high bits of X are
// already known to be zero. Otherwise, the mask has a semantic impact
// other than masking out a couple of low bits. Unfortunately, because of
// the mask, zero extensions will be removed from operands in some cases.
// This code works extra hard to look through extensions because we can
// replace them with zero extensions cheaply if necessary.
bool ReplacingAnyExtend = false;
if (X.getOpcode() == ISD::ANY_EXTEND) {
unsigned ExtendBits =
X.getValueSizeInBits() - X.getOperand(0).getValueSizeInBits();
// Assume that we'll replace the any-extend with a zero-extend, and
// narrow the search to the extended value.
X = X.getOperand(0);
MaskLZ = ExtendBits > MaskLZ ? 0 : MaskLZ - ExtendBits;
ReplacingAnyExtend = true;
}
APInt MaskedHighBits = APInt::getHighBitsSet(X.getValueSizeInBits(),
MaskLZ);
APInt KnownZero, KnownOne;
DAG.ComputeMaskedBits(X, MaskedHighBits, KnownZero, KnownOne);
if (MaskedHighBits != KnownZero) return true;
// We've identified a pattern that can be transformed into a single shift
// and an addressing mode. Make it so.
EVT VT = N.getValueType();
if (ReplacingAnyExtend) {
assert(X.getValueType() != VT);
// We looked through an ANY_EXTEND node, insert a ZERO_EXTEND.
SDValue NewX = DAG.getNode(ISD::ZERO_EXTEND, X.getDebugLoc(), VT, X);
if (NewX.getNode()->getNodeId() == -1 ||
NewX.getNode()->getNodeId() > N.getNode()->getNodeId()) {
DAG.RepositionNode(N.getNode(), NewX.getNode());
NewX.getNode()->setNodeId(N.getNode()->getNodeId());
}
X = NewX;
}
DebugLoc DL = N.getDebugLoc();
SDValue NewSRLAmt = DAG.getConstant(ShiftAmt + AMShiftAmt, MVT::i8);
SDValue NewSRL = DAG.getNode(ISD::SRL, DL, VT, X, NewSRLAmt);
SDValue NewSHLAmt = DAG.getConstant(AMShiftAmt, MVT::i8);
SDValue NewSHL = DAG.getNode(ISD::SHL, DL, VT, NewSRL, NewSHLAmt);
if (NewSRLAmt.getNode()->getNodeId() == -1 ||
NewSRLAmt.getNode()->getNodeId() > N.getNode()->getNodeId()) {
DAG.RepositionNode(N.getNode(), NewSRLAmt.getNode());
NewSRLAmt.getNode()->setNodeId(N.getNode()->getNodeId());
}
if (NewSRL.getNode()->getNodeId() == -1 ||
NewSRL.getNode()->getNodeId() > N.getNode()->getNodeId()) {
DAG.RepositionNode(N.getNode(), NewSRL.getNode());
NewSRL.getNode()->setNodeId(N.getNode()->getNodeId());
}
if (NewSHLAmt.getNode()->getNodeId() == -1 ||
NewSHLAmt.getNode()->getNodeId() > N.getNode()->getNodeId()) {
DAG.RepositionNode(N.getNode(), NewSHLAmt.getNode());
NewSHLAmt.getNode()->setNodeId(N.getNode()->getNodeId());
}
if (NewSHL.getNode()->getNodeId() == -1 ||
NewSHL.getNode()->getNodeId() > N.getNode()->getNodeId()) {
DAG.RepositionNode(N.getNode(), NewSHL.getNode());
NewSHL.getNode()->setNodeId(N.getNode()->getNodeId());
}
DAG.ReplaceAllUsesWith(N, NewSHL);
AM.Scale = 1 << AMShiftAmt;
AM.IndexReg = NewSRL;
return false;
}
bool X86DAGToDAGISel::MatchAddressRecursively(SDValue N, X86ISelAddressMode &AM,
unsigned Depth) {
DebugLoc dl = N.getDebugLoc();
@ -814,6 +948,13 @@ bool X86DAGToDAGISel::MatchAddressRecursively(SDValue N, X86ISelAddressMode &AM,
break;
}
case ISD::SRL:
// Try to fold the mask and shift into the scale, and return false if we
// succeed.
if (!FoldMaskAndShiftToScale(*CurDAG, N, AM))
return false;
break;
case ISD::SMUL_LOHI:
case ISD::UMUL_LOHI:
// A mul_lohi where we need the low part can be folded as a plain multiply.
@ -1047,6 +1188,11 @@ bool X86DAGToDAGISel::MatchAddressRecursively(SDValue N, X86ISelAddressMode &AM,
}
}
// Try to fold the mask and shift into the scale, and return false if we
// succeed.
if (!FoldMaskAndShiftToScale(*CurDAG, N, AM))
return false;
// Handle "(X << C1) & C2" as "(X & (C2>>C1)) << C1" if safe and if this
// allows us to fold the shift into this addressing mode.
if (Shift.getOpcode() != ISD::SHL) break;

View File

@ -31,3 +31,47 @@ entry:
%tmp9 = load i32* %tmp78
ret i32 %tmp9
}
define i32 @t3(i16* %i.ptr, i32* %arr) {
; This case is tricky. The lshr followed by a gep will produce a lshr followed
; by an and to remove the low bits. This can be simplified by doing the lshr by
; a greater constant and using the addressing mode to scale the result back up.
; To make matters worse, because of the two-phase zext of %i and their reuse in
; the function, the DAG can get confusing trying to re-use both of them and
; prevent easy analysis of the mask in order to match this.
; CHECK: t3:
; CHECK-NOT: and
; CHECK: shrl
; CHECK: addl (%{{...}},%{{...}},4),
; CHECK: ret
entry:
%i = load i16* %i.ptr
%i.zext = zext i16 %i to i32
%index = lshr i32 %i.zext, 11
%val.ptr = getelementptr inbounds i32* %arr, i32 %index
%val = load i32* %val.ptr
%sum = add i32 %val, %i.zext
ret i32 %sum
}
define i32 @t4(i16* %i.ptr, i32* %arr) {
; A version of @t3 that has more zero extends and more re-use of intermediate
; values. This exercise slightly different bits of canonicalization.
; CHECK: t4:
; CHECK-NOT: and
; CHECK: shrl
; CHECK: addl (%{{...}},%{{...}},4),
; CHECK: ret
entry:
%i = load i16* %i.ptr
%i.zext = zext i16 %i to i32
%index = lshr i32 %i.zext, 11
%index.zext = zext i32 %index to i64
%val.ptr = getelementptr inbounds i32* %arr, i64 %index.zext
%val = load i32* %val.ptr
%sum.1 = add i32 %val, %i.zext
%sum.2 = add i32 %sum.1, %index
ret i32 %sum.2
}