[NVPTX] Clean up handling of formal arguments and enable generation of vector parameter loads

git-svn-id: https://llvm.org/svn/llvm-project/llvm/trunk@185172 91177308-0d34-0410-b5e6-96231b3b80d8
This commit is contained in:
Justin Holewinski 2013-06-28 17:57:53 +00:00
parent 00df125228
commit b673665143
2 changed files with 208 additions and 122 deletions

View File

@ -1066,12 +1066,16 @@ SDValue NVPTXTargetLowering::LowerFormalArguments(
const Function *F = MF.getFunction();
const AttributeSet &PAL = F->getAttributes();
const TargetLowering *TLI = nvTM->getTargetLowering();
SDValue Root = DAG.getRoot();
std::vector<SDValue> OutChains;
bool isKernel = llvm::isKernelFunction(*F);
bool isABI = (nvptxSubtarget.getSmVersion() >= 20);
assert(isABI && "Non-ABI compilation is not supported");
if (!isABI)
return Chain;
std::vector<Type *> argTypes;
std::vector<const Argument *> theArgs;
@ -1080,15 +1084,20 @@ SDValue NVPTXTargetLowering::LowerFormalArguments(
theArgs.push_back(I);
argTypes.push_back(I->getType());
}
//assert(argTypes.size() == Ins.size() &&
// "Ins types and function types did not match");
// argTypes.size() (or theArgs.size()) and Ins.size() need not match.
// Ins.size() will be larger
// * if there is an aggregate argument with multiple fields (each field
// showing up separately in Ins)
// * if there is a vector argument with more than typical vector-length
// elements (generally if more than 4) where each vector element is
// individually present in Ins.
// So a different index should be used for indexing into Ins.
// See similar issue in LowerCall.
unsigned InsIdx = 0;
int idx = 0;
for (unsigned i = 0, e = argTypes.size(); i != e; ++i, ++idx) {
for (unsigned i = 0, e = theArgs.size(); i != e; ++i, ++idx, ++InsIdx) {
Type *Ty = argTypes[i];
EVT ObjectVT = getValueType(Ty);
//assert(ObjectVT == Ins[i].VT &&
// "Ins type did not match function type");
// If the kernel argument is image*_t or sampler_t, convert it to
// a i32 constant holding the parameter position. This can later
@ -1104,142 +1113,220 @@ SDValue NVPTXTargetLowering::LowerFormalArguments(
if (theArgs[i]->use_empty()) {
// argument is dead
if (ObjectVT.isVector()) {
EVT EltVT = ObjectVT.getVectorElementType();
unsigned NumElts = ObjectVT.getVectorNumElements();
for (unsigned vi = 0; vi < NumElts; ++vi) {
InVals.push_back(DAG.getNode(ISD::UNDEF, dl, EltVT));
if (Ty->isAggregateType()) {
SmallVector<EVT, 16> vtparts;
ComputeValueVTs(*this, Ty, vtparts);
assert(vtparts.size() > 0 && "empty aggregate type not expected");
for (unsigned parti = 0, parte = vtparts.size(); parti != parte;
++parti) {
EVT partVT = vtparts[parti];
InVals.push_back(DAG.getNode(ISD::UNDEF, dl, partVT));
++InsIdx;
}
} else {
InVals.push_back(DAG.getNode(ISD::UNDEF, dl, ObjectVT));
if (vtparts.size() > 0)
--InsIdx;
continue;
}
if (Ty->isVectorTy()) {
EVT ObjectVT = getValueType(Ty);
unsigned NumRegs = TLI->getNumRegisters(F->getContext(), ObjectVT);
for (unsigned parti = 0; parti < NumRegs; ++parti) {
InVals.push_back(DAG.getNode(ISD::UNDEF, dl, Ins[InsIdx].VT));
++InsIdx;
}
if (NumRegs > 0)
--InsIdx;
continue;
}
InVals.push_back(DAG.getNode(ISD::UNDEF, dl, Ins[InsIdx].VT));
continue;
}
// In the following cases, assign a node order of "idx+1"
// to newly created nodes. The SDNOdes for params have to
// to newly created nodes. The SDNodes for params have to
// appear in the same order as their order of appearance
// in the original function. "idx+1" holds that order.
if (PAL.hasAttribute(i + 1, Attribute::ByVal) == false) {
if (ObjectVT.isVector()) {
unsigned NumElts = ObjectVT.getVectorNumElements();
EVT EltVT = ObjectVT.getVectorElementType();
unsigned Offset = 0;
for (unsigned vi = 0; vi < NumElts; ++vi) {
SDValue A = getParamSymbol(DAG, idx, getPointerTy());
SDValue B = DAG.getIntPtrConstant(Offset);
SDValue Addr = DAG.getNode(ISD::ADD, dl, getPointerTy(),
//getParamSymbol(DAG, idx, EltVT),
//DAG.getConstant(Offset, getPointerTy()));
A, B);
Value *SrcValue = Constant::getNullValue(PointerType::get(
EltVT.getTypeForEVT(F->getContext()), llvm::ADDRESS_SPACE_PARAM));
SDValue Ld = DAG.getLoad(
EltVT, dl, Root, Addr, MachinePointerInfo(SrcValue), false, false,
false,
TD->getABITypeAlignment(EltVT.getTypeForEVT(F->getContext())));
Offset += EltVT.getStoreSizeInBits() / 8;
InVals.push_back(Ld);
if (Ty->isAggregateType()) {
SmallVector<EVT, 16> vtparts;
SmallVector<uint64_t, 16> offsets;
ComputeValueVTs(*this, Ty, vtparts, &offsets, 0);
assert(vtparts.size() > 0 && "empty aggregate type not expected");
bool aggregateIsPacked = false;
if (StructType *STy = llvm::dyn_cast<StructType>(Ty))
aggregateIsPacked = STy->isPacked();
SDValue Arg = getParamSymbol(DAG, idx, getPointerTy());
for (unsigned parti = 0, parte = vtparts.size(); parti != parte;
++parti) {
EVT partVT = vtparts[parti];
Value *srcValue = Constant::getNullValue(
PointerType::get(partVT.getTypeForEVT(F->getContext()),
llvm::ADDRESS_SPACE_PARAM));
SDValue srcAddr =
DAG.getNode(ISD::ADD, dl, getPointerTy(), Arg,
DAG.getConstant(offsets[parti], getPointerTy()));
unsigned partAlign =
aggregateIsPacked ? 1
: TD->getABITypeAlignment(
partVT.getTypeForEVT(F->getContext()));
SDValue p = DAG.getLoad(partVT, dl, Root, srcAddr,
MachinePointerInfo(srcValue), false, false,
true, partAlign);
if (p.getNode())
p.getNode()->setIROrder(idx + 1);
InVals.push_back(p);
++InsIdx;
}
if (vtparts.size() > 0)
--InsIdx;
continue;
}
// A plain scalar.
if (isABI || isKernel) {
// If ABI, load from the param symbol
if (Ty->isVectorTy()) {
EVT ObjectVT = getValueType(Ty);
SDValue Arg = getParamSymbol(DAG, idx, getPointerTy());
// Conjure up a value that we can get the address space from.
// FIXME: Using a constant here is a hack.
Value *srcValue = Constant::getNullValue(
PointerType::get(ObjectVT.getTypeForEVT(F->getContext()),
llvm::ADDRESS_SPACE_PARAM));
SDValue p = DAG.getLoad(
ObjectVT, dl, Root, Arg, MachinePointerInfo(srcValue), false, false,
false,
TD->getABITypeAlignment(ObjectVT.getTypeForEVT(F->getContext())));
if (p.getNode())
p.getNode()->setIROrder(idx + 1);
InVals.push_back(p);
} else {
// If no ABI, just move the param symbol
SDValue Arg = getParamSymbol(DAG, idx, ObjectVT);
SDValue p = DAG.getNode(NVPTXISD::MoveParam, dl, ObjectVT, Arg);
if (p.getNode())
p.getNode()->setIROrder(idx + 1);
InVals.push_back(p);
unsigned NumElts = ObjectVT.getVectorNumElements();
assert(TLI->getNumRegisters(F->getContext(), ObjectVT) == NumElts &&
"Vector was not scalarized");
unsigned Ofst = 0;
EVT EltVT = ObjectVT.getVectorElementType();
// V1 load
// f32 = load ...
if (NumElts == 1) {
// We only have one element, so just directly load it
Value *SrcValue = Constant::getNullValue(PointerType::get(
EltVT.getTypeForEVT(F->getContext()), llvm::ADDRESS_SPACE_PARAM));
SDValue SrcAddr = DAG.getNode(ISD::ADD, dl, getPointerTy(), Arg,
DAG.getConstant(Ofst, getPointerTy()));
SDValue P = DAG.getLoad(
EltVT, dl, Root, SrcAddr, MachinePointerInfo(SrcValue), false,
false, true,
TD->getABITypeAlignment(EltVT.getTypeForEVT(F->getContext())));
if (P.getNode())
P.getNode()->setIROrder(idx + 1);
InVals.push_back(P);
Ofst += TD->getTypeAllocSize(EltVT.getTypeForEVT(F->getContext()));
++InsIdx;
} else if (NumElts == 2) {
// V2 load
// f32,f32 = load ...
EVT VecVT = EVT::getVectorVT(F->getContext(), EltVT, 2);
Value *SrcValue = Constant::getNullValue(PointerType::get(
VecVT.getTypeForEVT(F->getContext()), llvm::ADDRESS_SPACE_PARAM));
SDValue SrcAddr = DAG.getNode(ISD::ADD, dl, getPointerTy(), Arg,
DAG.getConstant(Ofst, getPointerTy()));
SDValue P = DAG.getLoad(
VecVT, dl, Root, SrcAddr, MachinePointerInfo(SrcValue), false,
false, true,
TD->getABITypeAlignment(VecVT.getTypeForEVT(F->getContext())));
if (P.getNode())
P.getNode()->setIROrder(idx + 1);
SDValue Elt0 = DAG.getNode(ISD::EXTRACT_VECTOR_ELT, dl, EltVT, P,
DAG.getIntPtrConstant(0));
SDValue Elt1 = DAG.getNode(ISD::EXTRACT_VECTOR_ELT, dl, EltVT, P,
DAG.getIntPtrConstant(1));
InVals.push_back(Elt0);
InVals.push_back(Elt1);
Ofst += TD->getTypeAllocSize(VecVT.getTypeForEVT(F->getContext()));
InsIdx += 2;
} else {
// V4 loads
// We have at least 4 elements (<3 x Ty> expands to 4 elements) and
// the
// vector will be expanded to a power of 2 elements, so we know we can
// always round up to the next multiple of 4 when creating the vector
// loads.
// e.g. 4 elem => 1 ld.v4
// 6 elem => 2 ld.v4
// 8 elem => 2 ld.v4
// 11 elem => 3 ld.v4
unsigned VecSize = 4;
if (EltVT.getSizeInBits() == 64) {
VecSize = 2;
}
EVT VecVT = EVT::getVectorVT(F->getContext(), EltVT, VecSize);
for (unsigned i = 0; i < NumElts; i += VecSize) {
Value *SrcValue = Constant::getNullValue(
PointerType::get(VecVT.getTypeForEVT(F->getContext()),
llvm::ADDRESS_SPACE_PARAM));
SDValue SrcAddr =
DAG.getNode(ISD::ADD, dl, getPointerTy(), Arg,
DAG.getConstant(Ofst, getPointerTy()));
SDValue P = DAG.getLoad(
VecVT, dl, Root, SrcAddr, MachinePointerInfo(SrcValue), false,
false, true,
TD->getABITypeAlignment(VecVT.getTypeForEVT(F->getContext())));
if (P.getNode())
P.getNode()->setIROrder(idx + 1);
for (unsigned j = 0; j < VecSize; ++j) {
if (i + j >= NumElts)
break;
SDValue Elt = DAG.getNode(ISD::EXTRACT_VECTOR_ELT, dl, EltVT, P,
DAG.getIntPtrConstant(j));
InVals.push_back(Elt);
}
Ofst += TD->getTypeAllocSize(VecVT.getTypeForEVT(F->getContext()));
InsIdx += VecSize;
}
}
if (NumElts > 0)
--InsIdx;
continue;
}
// A plain scalar.
EVT ObjectVT = getValueType(Ty);
assert(ObjectVT == Ins[InsIdx].VT &&
"Ins type did not match function type");
// If ABI, load from the param symbol
SDValue Arg = getParamSymbol(DAG, idx, getPointerTy());
Value *srcValue = Constant::getNullValue(PointerType::get(
ObjectVT.getTypeForEVT(F->getContext()), llvm::ADDRESS_SPACE_PARAM));
SDValue p = DAG.getLoad(
ObjectVT, dl, Root, Arg, MachinePointerInfo(srcValue), false, false,
true,
TD->getABITypeAlignment(ObjectVT.getTypeForEVT(F->getContext())));
if (p.getNode())
p.getNode()->setIROrder(idx + 1);
InVals.push_back(p);
continue;
}
// Param has ByVal attribute
if (isABI || isKernel) {
// Return MoveParam(param symbol).
// Ideally, the param symbol can be returned directly,
// but when SDNode builder decides to use it in a CopyToReg(),
// machine instruction fails because TargetExternalSymbol
// (not lowered) is target dependent, and CopyToReg assumes
// the source is lowered.
SDValue Arg = getParamSymbol(DAG, idx, getPointerTy());
SDValue p = DAG.getNode(NVPTXISD::MoveParam, dl, ObjectVT, Arg);
if (p.getNode())
p.getNode()->setIROrder(idx + 1);
if (isKernel)
InVals.push_back(p);
else {
SDValue p2 = DAG.getNode(
ISD::INTRINSIC_WO_CHAIN, dl, ObjectVT,
DAG.getConstant(Intrinsic::nvvm_ptr_local_to_gen, MVT::i32), p);
InVals.push_back(p2);
}
} else {
// Have to move a set of param symbols to registers and
// store them locally and return the local pointer in InVals
const PointerType *elemPtrType = dyn_cast<PointerType>(argTypes[i]);
assert(elemPtrType && "Byval parameter should be a pointer type");
Type *elemType = elemPtrType->getElementType();
// Compute the constituent parts
SmallVector<EVT, 16> vtparts;
SmallVector<uint64_t, 16> offsets;
ComputeValueVTs(*this, elemType, vtparts, &offsets, 0);
unsigned totalsize = 0;
for (unsigned j = 0, je = vtparts.size(); j != je; ++j)
totalsize += vtparts[j].getStoreSizeInBits();
SDValue localcopy = DAG.getFrameIndex(
MF.getFrameInfo()->CreateStackObject(totalsize / 8, 16, false),
getPointerTy());
unsigned sizesofar = 0;
std::vector<SDValue> theChains;
for (unsigned j = 0, je = vtparts.size(); j != je; ++j) {
unsigned numElems = 1;
if (vtparts[j].isVector())
numElems = vtparts[j].getVectorNumElements();
for (unsigned k = 0, ke = numElems; k != ke; ++k) {
EVT tmpvt = vtparts[j];
if (tmpvt.isVector())
tmpvt = tmpvt.getVectorElementType();
SDValue arg = DAG.getNode(NVPTXISD::MoveParam, dl, tmpvt,
getParamSymbol(DAG, idx, tmpvt));
SDValue addr =
DAG.getNode(ISD::ADD, dl, getPointerTy(), localcopy,
DAG.getConstant(sizesofar, getPointerTy()));
theChains.push_back(DAG.getStore(
Chain, dl, arg, addr, MachinePointerInfo(), false, false, 0));
sizesofar += tmpvt.getStoreSizeInBits() / 8;
++idx;
}
}
--idx;
Chain = DAG.getNode(ISD::TokenFactor, dl, MVT::Other, &theChains[0],
theChains.size());
InVals.push_back(localcopy);
// Return MoveParam(param symbol).
// Ideally, the param symbol can be returned directly,
// but when SDNode builder decides to use it in a CopyToReg(),
// machine instruction fails because TargetExternalSymbol
// (not lowered) is target dependent, and CopyToReg assumes
// the source is lowered.
EVT ObjectVT = getValueType(Ty);
assert(ObjectVT == Ins[InsIdx].VT &&
"Ins type did not match function type");
SDValue Arg = getParamSymbol(DAG, idx, getPointerTy());
SDValue p = DAG.getNode(NVPTXISD::MoveParam, dl, ObjectVT, Arg);
if (p.getNode())
p.getNode()->setIROrder(idx + 1);
if (isKernel)
InVals.push_back(p);
else {
SDValue p2 = DAG.getNode(
ISD::INTRINSIC_WO_CHAIN, dl, ObjectVT,
DAG.getConstant(Intrinsic::nvvm_ptr_local_to_gen, MVT::i32), p);
InVals.push_back(p2);
}
}
// Clang will check explicit VarArg and issue error if any. However, Clang
// will let code with
// implicit var arg like f() pass.
// implicit var arg like f() pass. See bug 617733.
// We treat this case as if the arg list is empty.
//if (F.isVarArg()) {
// if (F.isVarArg()) {
// assert(0 && "VarArg not supported yet!");
//}
@ -1250,6 +1337,7 @@ SDValue NVPTXTargetLowering::LowerFormalArguments(
return Chain;
}
SDValue NVPTXTargetLowering::LowerReturn(
SDValue Chain, CallingConv::ID CallConv, bool isVarArg,
const SmallVectorImpl<ISD::OutputArg> &Outs,

View File

@ -4,8 +4,7 @@
define float @foo(<2 x float> %a) {
; CHECK: .func (.param .b32 func_retval0) foo
; CHECK: .param .align 8 .b8 foo_param_0[8]
; CHECK: ld.param.f32 %f{{[0-9]+}}
; CHECK: ld.param.f32 %f{{[0-9]+}}
; CHECK: ld.param.v2.f32 {%f{{[0-9]+}}, %f{{[0-9]+}}}
%t1 = fmul <2 x float> %a, %a
%t2 = extractelement <2 x float> %t1, i32 0
%t3 = extractelement <2 x float> %t1, i32 1
@ -17,8 +16,7 @@ define float @foo(<2 x float> %a) {
define float @bar(<4 x float> %a) {
; CHECK: .func (.param .b32 func_retval0) bar
; CHECK: .param .align 16 .b8 bar_param_0[16]
; CHECK: ld.param.f32 %f{{[0-9]+}}
; CHECK: ld.param.f32 %f{{[0-9]+}}
; CHECK: ld.param.v4.f32 {%f{{[0-9]+}}, %f{{[0-9]+}}, %f{{[0-9]+}}, %f{{[0-9]+}}}
%t1 = fmul <4 x float> %a, %a
%t2 = extractelement <4 x float> %t1, i32 0
%t3 = extractelement <4 x float> %t1, i32 1