R600: Fix LowerSDIV24

Use ComputeNumSignBits instead of checking for i8 / i16 which only
worked when AMDIL was lying about having legal i8 / i16.

If an integer is known to fit in 24-bits, we can
do division faster with float ops.

git-svn-id: https://llvm.org/svn/llvm-project/llvm/trunk@213843 91177308-0d34-0410-b5e6-96231b3b80d8
This commit is contained in:
Matt Arsenault 2014-07-24 06:59:20 +00:00
parent 8971050012
commit f303d037f2
2 changed files with 170 additions and 51 deletions

View File

@ -250,7 +250,7 @@ AMDGPUTargetLowering::AMDGPUTargetLowering(TargetMachine &TM) :
const MVT ScalarIntVTs[] = { MVT::i32, MVT::i64 };
for (MVT VT : ScalarIntVTs) {
setOperationAction(ISD::SREM, VT, Expand);
setOperationAction(ISD::SDIV, VT, Expand);
setOperationAction(ISD::SDIV, VT, Custom);
// GPU does not have divrem function for signed or unsigned.
setOperationAction(ISD::SDIVREM, VT, Custom);
@ -1272,85 +1272,83 @@ SDValue AMDGPUTargetLowering::LowerSTORE(SDValue Op, SelectionDAG &DAG) const {
return SDValue();
}
// This is a shortcut for integer division because we have fast i32<->f32
// conversions, and fast f32 reciprocal instructions. The fractional part of a
// float is enough to accurately represent up to a 24-bit integer.
SDValue AMDGPUTargetLowering::LowerSDIV24(SDValue Op, SelectionDAG &DAG) const {
SDLoc DL(Op);
EVT OVT = Op.getValueType();
EVT VT = Op.getValueType();
SDValue LHS = Op.getOperand(0);
SDValue RHS = Op.getOperand(1);
MVT INTTY;
MVT FLTTY;
if (!OVT.isVector()) {
INTTY = MVT::i32;
FLTTY = MVT::f32;
} else if (OVT.getVectorNumElements() == 2) {
INTTY = MVT::v2i32;
FLTTY = MVT::v2f32;
} else if (OVT.getVectorNumElements() == 4) {
INTTY = MVT::v4i32;
FLTTY = MVT::v4f32;
MVT IntVT = MVT::i32;
MVT FltVT = MVT::f32;
if (VT.isVector()) {
unsigned NElts = VT.getVectorNumElements();
IntVT = MVT::getVectorVT(MVT::i32, NElts);
FltVT = MVT::getVectorVT(MVT::f32, NElts);
}
unsigned bitsize = OVT.getScalarType().getSizeInBits();
unsigned BitSize = VT.getScalarType().getSizeInBits();
// char|short jq = ia ^ ib;
SDValue jq = DAG.getNode(ISD::XOR, DL, OVT, LHS, RHS);
SDValue jq = DAG.getNode(ISD::XOR, DL, VT, LHS, RHS);
// jq = jq >> (bitsize - 2)
jq = DAG.getNode(ISD::SRA, DL, OVT, jq, DAG.getConstant(bitsize - 2, OVT));
jq = DAG.getNode(ISD::SRA, DL, VT, jq, DAG.getConstant(BitSize - 2, VT));
// jq = jq | 0x1
jq = DAG.getNode(ISD::OR, DL, OVT, jq, DAG.getConstant(1, OVT));
jq = DAG.getNode(ISD::OR, DL, VT, jq, DAG.getConstant(1, VT));
// jq = (int)jq
jq = DAG.getSExtOrTrunc(jq, DL, INTTY);
jq = DAG.getSExtOrTrunc(jq, DL, IntVT);
// int ia = (int)LHS;
SDValue ia = DAG.getSExtOrTrunc(LHS, DL, INTTY);
SDValue ia = DAG.getSExtOrTrunc(LHS, DL, IntVT);
// int ib, (int)RHS;
SDValue ib = DAG.getSExtOrTrunc(RHS, DL, INTTY);
SDValue ib = DAG.getSExtOrTrunc(RHS, DL, IntVT);
// float fa = (float)ia;
SDValue fa = DAG.getNode(ISD::SINT_TO_FP, DL, FLTTY, ia);
SDValue fa = DAG.getNode(ISD::SINT_TO_FP, DL, FltVT, ia);
// float fb = (float)ib;
SDValue fb = DAG.getNode(ISD::SINT_TO_FP, DL, FLTTY, ib);
SDValue fb = DAG.getNode(ISD::SINT_TO_FP, DL, FltVT, ib);
// float fq = native_divide(fa, fb);
SDValue fq = DAG.getNode(ISD::FMUL, DL, FLTTY,
fa, DAG.getNode(AMDGPUISD::RCP, DL, FLTTY, fb));
SDValue fq = DAG.getNode(ISD::FMUL, DL, FltVT,
fa, DAG.getNode(AMDGPUISD::RCP, DL, FltVT, fb));
// fq = trunc(fq);
fq = DAG.getNode(ISD::FTRUNC, DL, FLTTY, fq);
fq = DAG.getNode(ISD::FTRUNC, DL, FltVT, fq);
// float fqneg = -fq;
SDValue fqneg = DAG.getNode(ISD::FNEG, DL, FLTTY, fq);
SDValue fqneg = DAG.getNode(ISD::FNEG, DL, FltVT, fq);
// float fr = mad(fqneg, fb, fa);
SDValue fr = DAG.getNode(ISD::FADD, DL, FLTTY,
DAG.getNode(ISD::MUL, DL, FLTTY, fqneg, fb), fa);
SDValue fr = DAG.getNode(ISD::FADD, DL, FltVT,
DAG.getNode(ISD::FMUL, DL, FltVT, fqneg, fb), fa);
// int iq = (int)fq;
SDValue iq = DAG.getNode(ISD::FP_TO_SINT, DL, INTTY, fq);
SDValue iq = DAG.getNode(ISD::FP_TO_SINT, DL, IntVT, fq);
// fr = fabs(fr);
fr = DAG.getNode(ISD::FABS, DL, FLTTY, fr);
fr = DAG.getNode(ISD::FABS, DL, FltVT, fr);
// fb = fabs(fb);
fb = DAG.getNode(ISD::FABS, DL, FLTTY, fb);
fb = DAG.getNode(ISD::FABS, DL, FltVT, fb);
EVT SetCCVT = getSetCCResultType(*DAG.getContext(), VT);
// int cv = fr >= fb;
SDValue cv;
if (INTTY == MVT::i32) {
cv = DAG.getSetCC(DL, INTTY, fr, fb, ISD::SETOGE);
} else {
cv = DAG.getSetCC(DL, INTTY, fr, fb, ISD::SETOGE);
}
SDValue cv = DAG.getSetCC(DL, SetCCVT, fr, fb, ISD::SETOGE);
// jq = (cv ? jq : 0);
jq = DAG.getNode(ISD::SELECT, DL, OVT, cv, jq,
DAG.getConstant(0, OVT));
jq = DAG.getNode(ISD::SELECT, DL, VT, cv, jq, DAG.getConstant(0, VT));
// dst = iq + jq;
iq = DAG.getSExtOrTrunc(iq, DL, OVT);
iq = DAG.getNode(ISD::ADD, DL, OVT, iq, jq);
return iq;
iq = DAG.getSExtOrTrunc(iq, DL, VT);
return DAG.getNode(ISD::ADD, DL, VT, iq, jq);
}
SDValue AMDGPUTargetLowering::LowerSDIV32(SDValue Op, SelectionDAG &DAG) const {
@ -1425,19 +1423,20 @@ SDValue AMDGPUTargetLowering::LowerSDIV64(SDValue Op, SelectionDAG &DAG) const {
SDValue AMDGPUTargetLowering::LowerSDIV(SDValue Op, SelectionDAG &DAG) const {
EVT OVT = Op.getValueType().getScalarType();
if (OVT == MVT::i64)
return LowerSDIV64(Op, DAG);
if (OVT == MVT::i32) {
if (DAG.ComputeNumSignBits(Op.getOperand(0)) > 8 &&
DAG.ComputeNumSignBits(Op.getOperand(1)) > 8) {
// TODO: We technically could do this for i64, but shouldn't that just be
// handled by something generally reducing 64-bit division on 32-bit
// values to 32-bit?
return LowerSDIV24(Op, DAG);
}
if (OVT.getScalarType() == MVT::i32)
return LowerSDIV32(Op, DAG);
if (OVT == MVT::i16 || OVT == MVT::i8) {
// FIXME: We should be checking for the masked bits. This isn't reached
// because i8 and i16 are not legal types.
return LowerSDIV24(Op, DAG);
}
return SDValue(Op.getNode(), 0);
assert(OVT == MVT::i64);
return LowerSDIV64(Op, DAG);
}
SDValue AMDGPUTargetLowering::LowerSREM32(SDValue Op, SelectionDAG &DAG) const {

120
test/CodeGen/R600/sdiv24.ll Normal file
View File

@ -0,0 +1,120 @@
; RUN: llc -march=r600 -mcpu=SI < %s | FileCheck -check-prefix=SI -check-prefix=FUNC %s
; RUN: llc -march=r600 -mcpu=redwood < %s | FileCheck -check-prefix=EG -check-prefix=FUNC %s
; FUNC-LABEL: @sdiv24_i8
; SI: V_CVT_F32_I32
; SI: V_CVT_F32_I32
; SI: V_RCP_F32
; SI: V_CVT_I32_F32
; EG: INT_TO_FLT
; EG-DAG: INT_TO_FLT
; EG-DAG: RECIP_IEEE
; EG: FLT_TO_INT
define void @sdiv24_i8(i8 addrspace(1)* %out, i8 addrspace(1)* %in) {
%den_ptr = getelementptr i8 addrspace(1)* %in, i8 1
%num = load i8 addrspace(1) * %in
%den = load i8 addrspace(1) * %den_ptr
%result = sdiv i8 %num, %den
store i8 %result, i8 addrspace(1)* %out
ret void
}
; FUNC-LABEL: @sdiv24_i16
; SI: V_CVT_F32_I32
; SI: V_CVT_F32_I32
; SI: V_RCP_F32
; SI: V_CVT_I32_F32
; EG: INT_TO_FLT
; EG-DAG: INT_TO_FLT
; EG-DAG: RECIP_IEEE
; EG: FLT_TO_INT
define void @sdiv24_i16(i16 addrspace(1)* %out, i16 addrspace(1)* %in) {
%den_ptr = getelementptr i16 addrspace(1)* %in, i16 1
%num = load i16 addrspace(1) * %in, align 2
%den = load i16 addrspace(1) * %den_ptr, align 2
%result = sdiv i16 %num, %den
store i16 %result, i16 addrspace(1)* %out, align 2
ret void
}
; FUNC-LABEL: @sdiv24_i32
; SI: V_CVT_F32_I32
; SI: V_CVT_F32_I32
; SI: V_RCP_F32
; SI: V_CVT_I32_F32
; EG: INT_TO_FLT
; EG-DAG: INT_TO_FLT
; EG-DAG: RECIP_IEEE
; EG: FLT_TO_INT
define void @sdiv24_i32(i32 addrspace(1)* %out, i32 addrspace(1)* %in) {
%den_ptr = getelementptr i32 addrspace(1)* %in, i32 1
%num = load i32 addrspace(1) * %in, align 4
%den = load i32 addrspace(1) * %den_ptr, align 4
%num.i24.0 = shl i32 %num, 8
%den.i24.0 = shl i32 %den, 8
%num.i24 = ashr i32 %num.i24.0, 8
%den.i24 = ashr i32 %den.i24.0, 8
%result = sdiv i32 %num.i24, %den.i24
store i32 %result, i32 addrspace(1)* %out, align 4
ret void
}
; FUNC-LABEL: @sdiv25_i32
; SI-NOT: V_CVT_F32_I32
; SI-NOT: V_RCP_F32
; EG-NOT: INT_TO_FLT
; EG-NOT: RECIP_IEEE
define void @sdiv25_i32(i32 addrspace(1)* %out, i32 addrspace(1)* %in) {
%den_ptr = getelementptr i32 addrspace(1)* %in, i32 1
%num = load i32 addrspace(1) * %in, align 4
%den = load i32 addrspace(1) * %den_ptr, align 4
%num.i24.0 = shl i32 %num, 7
%den.i24.0 = shl i32 %den, 7
%num.i24 = ashr i32 %num.i24.0, 7
%den.i24 = ashr i32 %den.i24.0, 7
%result = sdiv i32 %num.i24, %den.i24
store i32 %result, i32 addrspace(1)* %out, align 4
ret void
}
; FUNC-LABEL: @test_no_sdiv24_i32_1
; SI-NOT: V_CVT_F32_I32
; SI-NOT: V_RCP_F32
; EG-NOT: INT_TO_FLT
; EG-NOT: RECIP_IEEE
define void @test_no_sdiv24_i32_1(i32 addrspace(1)* %out, i32 addrspace(1)* %in) {
%den_ptr = getelementptr i32 addrspace(1)* %in, i32 1
%num = load i32 addrspace(1) * %in, align 4
%den = load i32 addrspace(1) * %den_ptr, align 4
%num.i24.0 = shl i32 %num, 8
%den.i24.0 = shl i32 %den, 7
%num.i24 = ashr i32 %num.i24.0, 8
%den.i24 = ashr i32 %den.i24.0, 7
%result = sdiv i32 %num.i24, %den.i24
store i32 %result, i32 addrspace(1)* %out, align 4
ret void
}
; FUNC-LABEL: @test_no_sdiv24_i32_2
; SI-NOT: V_CVT_F32_I32
; SI-NOT: V_RCP_F32
; EG-NOT: INT_TO_FLT
; EG-NOT: RECIP_IEEE
define void @test_no_sdiv24_i32_2(i32 addrspace(1)* %out, i32 addrspace(1)* %in) {
%den_ptr = getelementptr i32 addrspace(1)* %in, i32 1
%num = load i32 addrspace(1) * %in, align 4
%den = load i32 addrspace(1) * %den_ptr, align 4
%num.i24.0 = shl i32 %num, 7
%den.i24.0 = shl i32 %den, 8
%num.i24 = ashr i32 %num.i24.0, 7
%den.i24 = ashr i32 %den.i24.0, 8
%result = sdiv i32 %num.i24, %den.i24
store i32 %result, i32 addrspace(1)* %out, align 4
ret void
}