diff --git a/lib/Target/AMDGPU/AMDGPU.td b/lib/Target/AMDGPU/AMDGPU.td index ef8ef626854..68b50504ee4 100644 --- a/lib/Target/AMDGPU/AMDGPU.td +++ b/lib/Target/AMDGPU/AMDGPU.td @@ -123,6 +123,11 @@ def FeatureSGPRInitBug : SubtargetFeature<"sgpr-init-bug", "true", "VI SGPR initilization bug requiring a fixed SGPR allocation size">; +def FeatureEnableHugeScratchBuffer : SubtargetFeature<"huge-scratch-buffer", + "EnableHugeScratchBuffer", + "true", + "Enable scratch buffer sizes greater than 128 GB">; + class SubtargetFeatureFetchLimit : SubtargetFeature <"fetch"#Value, "TexVTXClauseSize", diff --git a/lib/Target/AMDGPU/AMDGPUSubtarget.cpp b/lib/Target/AMDGPU/AMDGPUSubtarget.cpp index bd5abc4f546..5f32a65c933 100644 --- a/lib/Target/AMDGPU/AMDGPUSubtarget.cpp +++ b/lib/Target/AMDGPU/AMDGPUSubtarget.cpp @@ -73,7 +73,7 @@ AMDGPUSubtarget::AMDGPUSubtarget(const Triple &TT, StringRef GPU, StringRef FS, WavefrontSize(0), CFALUBug(false), LocalMemorySize(0), EnableVGPRSpilling(false), SGPRInitBug(false), IsGCN(false), GCN1Encoding(false), GCN3Encoding(false), CIInsts(false), LDSBankCount(0), - IsaVersion(ISAVersion0_0_0), + IsaVersion(ISAVersion0_0_0), EnableHugeScratchBuffer(false), FrameLowering(TargetFrameLowering::StackGrowsUp, 64 * 16, // Maximum stack alignment (long16) 0), diff --git a/lib/Target/AMDGPU/AMDGPUSubtarget.h b/lib/Target/AMDGPU/AMDGPUSubtarget.h index 90831bfb445..735f01dfa7c 100644 --- a/lib/Target/AMDGPU/AMDGPUSubtarget.h +++ b/lib/Target/AMDGPU/AMDGPUSubtarget.h @@ -89,6 +89,7 @@ private: bool FeatureDisable; int LDSBankCount; unsigned IsaVersion; + bool EnableHugeScratchBuffer; AMDGPUFrameLowering FrameLowering; std::unique_ptr TLInfo; @@ -271,6 +272,10 @@ public: return DevName; } + bool enableHugeScratchBuffer() const { + return EnableHugeScratchBuffer; + } + bool dumpCode() const { return DumpCode; } diff --git a/lib/Target/AMDGPU/SIISelLowering.cpp b/lib/Target/AMDGPU/SIISelLowering.cpp index dd818a9ba74..8ae687c5e82 100644 --- a/lib/Target/AMDGPU/SIISelLowering.cpp +++ b/lib/Target/AMDGPU/SIISelLowering.cpp @@ -812,10 +812,29 @@ static SDNode *findUser(SDValue Value, unsigned Opcode) { SDValue SITargetLowering::LowerFrameIndex(SDValue Op, SelectionDAG &DAG) const { + SDLoc SL(Op); FrameIndexSDNode *FINode = cast(Op); unsigned FrameIndex = FINode->getIndex(); - return DAG.getTargetFrameIndex(FrameIndex, MVT::i32); + // A FrameIndex node represents a 32-bit offset into scratch memory. If + // the high bit of a frame index offset were to be set, this would mean + // that it represented an offset of ~2GB * 64 = ~128GB from the start of the + // scratch buffer, with 64 being the number of threads per wave. + // + // If we know the machine uses less than 128GB of scratch, then we can + // amrk the high bit of the FrameIndex node as known zero, + // which is important, because it means in most situations we can + // prove that values derived from FrameIndex nodes are non-negative. + // This enables us to take advantage of more addressing modes when + // accessing scratch buffers, since for scratch reads/writes, the register + // offset must always be positive. + + SDValue TFI = DAG.getTargetFrameIndex(FrameIndex, MVT::i32); + if (Subtarget->enableHugeScratchBuffer()) + return TFI; + + return DAG.getNode(ISD::AssertZext, SL, MVT::i32, TFI, + DAG.getValueType(EVT::getIntegerVT(*DAG.getContext(), 31))); } /// This transforms the control flow intrinsics to get the branch destination as @@ -2034,6 +2053,13 @@ void SITargetLowering::adjustWritemask(MachineSDNode *&Node, } } +static bool isFrameIndexOp(SDValue Op) { + if (Op.getOpcode() == ISD::AssertZext) + Op = Op.getOperand(0); + + return isa(Op); +} + /// \brief Legalize target independent instructions (e.g. INSERT_SUBREG) /// with frame index operands. /// LLVM assumes that inputs are to these instructions are registers. @@ -2042,7 +2068,7 @@ void SITargetLowering::legalizeTargetIndependentNode(SDNode *Node, SmallVector Ops; for (unsigned i = 0; i < Node->getNumOperands(); ++i) { - if (!isa(Node->getOperand(i))) { + if (!isFrameIndexOp(Node->getOperand(i))) { Ops.push_back(Node->getOperand(i)); continue; }