diff --git a/lib/CodeGen/SelectionDAG/LegalizeTypes.h b/lib/CodeGen/SelectionDAG/LegalizeTypes.h
index 20b7ce6b15b..8464b7d8083 100644
--- a/lib/CodeGen/SelectionDAG/LegalizeTypes.h
+++ b/lib/CodeGen/SelectionDAG/LegalizeTypes.h
@@ -578,6 +578,7 @@ private:
 
   // Vector Operand Splitting: <128 x ty> -> 2 x <64 x ty>.
   bool SplitVectorOperand(SDNode *N, unsigned OpNo);
+  SDValue SplitVecOp_VSELECT(SDNode *N, unsigned OpNo);
   SDValue SplitVecOp_UnaryOp(SDNode *N);
 
   SDValue SplitVecOp_BITCAST(SDNode *N);
diff --git a/lib/CodeGen/SelectionDAG/LegalizeVectorTypes.cpp b/lib/CodeGen/SelectionDAG/LegalizeVectorTypes.cpp
index d51a6eb192e..595d83b716b 100644
--- a/lib/CodeGen/SelectionDAG/LegalizeVectorTypes.cpp
+++ b/lib/CodeGen/SelectionDAG/LegalizeVectorTypes.cpp
@@ -1030,7 +1030,9 @@ bool DAGTypeLegalizer::SplitVectorOperand(SDNode *N, unsigned OpNo) {
     case ISD::STORE:
       Res = SplitVecOp_STORE(cast<StoreSDNode>(N), OpNo);
       break;
-
+    case ISD::VSELECT:
+      Res = SplitVecOp_VSELECT(N, OpNo);
+      break;
     case ISD::CTTZ:
     case ISD::CTLZ:
     case ISD::CTPOP:
@@ -1064,6 +1066,71 @@ bool DAGTypeLegalizer::SplitVectorOperand(SDNode *N, unsigned OpNo) {
   return false;
 }
 
+/// Split the illegal mask operand of a VSELECT node.
+///
+/// The mask and both value operands are cut into a low and a high half with
+/// EXTRACT_SUBVECTOR, one VSELECT is built per half, and the two partial
+/// results are joined by CONCAT_VECTORS into the type of the value operands.
+SDValue DAGTypeLegalizer::SplitVecOp_VSELECT(SDNode *N, unsigned OpNo) {
+  // The only possibility for an illegal operand is the mask, since result type
+  // legalization would have handled this node already otherwise.
+  assert(OpNo == 0 && "Illegal operand must be mask");
+
+  SDValue Mask = N->getOperand(0);
+  SDValue Src0 = N->getOperand(1);
+  SDValue Src1 = N->getOperand(2);
+  DebugLoc DL = N->getDebugLoc();
+  EVT MaskVT = Mask.getValueType();
+  EVT Src0VT = Src0.getValueType();
+  assert(MaskVT.isVector() && "VSELECT without a vector mask?");
+
+  // Only the element counts of the split mask halves are used here; the
+  // sub-masks themselves are extracted from the original mask below.
+  SDValue Lo, Hi;
+  GetSplitVector(Mask, Lo, Hi);
+
+  unsigned LoNumElts = Lo.getValueType().getVectorNumElements();
+  unsigned HiNumElts = Hi.getValueType().getVectorNumElements();
+  assert(LoNumElts == HiNumElts && "Asymmetric vector split?");
+
+  EVT LoOpVT = EVT::getVectorVT(*DAG.getContext(),
+                                Src0VT.getVectorElementType(),
+                                LoNumElts);
+  EVT LoMaskVT = EVT::getVectorVT(*DAG.getContext(),
+                                  MaskVT.getVectorElementType(),
+                                  LoNumElts);
+  EVT HiOpVT = EVT::getVectorVT(*DAG.getContext(),
+                                Src0VT.getVectorElementType(),
+                                HiNumElts);
+  EVT HiMaskVT = EVT::getVectorVT(*DAG.getContext(),
+                                  MaskVT.getVectorElementType(),
+                                  HiNumElts);
+
+  // The low halves start at element 0.
+  SDValue LoOp0 = DAG.getNode(ISD::EXTRACT_SUBVECTOR, DL, LoOpVT, Src0,
+                              DAG.getIntPtrConstant(0));
+  SDValue LoOp1 = DAG.getNode(ISD::EXTRACT_SUBVECTOR, DL, LoOpVT, Src1,
+                              DAG.getIntPtrConstant(0));
+
+  // The high halves start right after the low ones.
+  SDValue HiOp0 = DAG.getNode(ISD::EXTRACT_SUBVECTOR, DL, HiOpVT, Src0,
+                              DAG.getIntPtrConstant(LoNumElts));
+  SDValue HiOp1 = DAG.getNode(ISD::EXTRACT_SUBVECTOR, DL, HiOpVT, Src1,
+                              DAG.getIntPtrConstant(LoNumElts));
+
+  SDValue LoMask = DAG.getNode(ISD::EXTRACT_SUBVECTOR, DL, LoMaskVT, Mask,
+                               DAG.getIntPtrConstant(0));
+  SDValue HiMask = DAG.getNode(ISD::EXTRACT_SUBVECTOR, DL, HiMaskVT, Mask,
+                               DAG.getIntPtrConstant(LoNumElts));
+
+  SDValue LoSelect = DAG.getNode(ISD::VSELECT, DL, LoOpVT, LoMask, LoOp0,
+                                 LoOp1);
+  SDValue HiSelect = DAG.getNode(ISD::VSELECT, DL, HiOpVT, HiMask, HiOp0,
+                                 HiOp1);
+
+  return DAG.getNode(ISD::CONCAT_VECTORS, DL, Src0VT, LoSelect, HiSelect);
+}
+
 SDValue DAGTypeLegalizer::SplitVecOp_UnaryOp(SDNode *N) {
   // The result has a legal vector type, but the input needs splitting.
   EVT ResVT = N->getValueType(0);
diff --git a/test/CodeGen/NVPTX/vector-select.ll b/test/CodeGen/NVPTX/vector-select.ll
new file mode 100644
index 00000000000..11893df1032
--- /dev/null
+++ b/test/CodeGen/NVPTX/vector-select.ll
@@ -0,0 +1,16 @@
+; RUN: llc < %s -march=nvptx -mcpu=sm_20
+; RUN: llc < %s -march=nvptx64 -mcpu=sm_20
+
+; This test makes sure that vector selects are scalarized by the type legalizer.
+; If not, type legalization will fail.
+
+define void @foo(<2 x i32> addrspace(1)* %def_a, <2 x i32> addrspace(1)* %def_b, <2 x i32> addrspace(1)* %def_c) {
+entry:
+  %tmp4 = load <2 x i32> addrspace(1)* %def_a
+  %tmp6 = load <2 x i32> addrspace(1)* %def_c
+  %tmp8 = load <2 x i32> addrspace(1)* %def_b
+  %0 = icmp sge <2 x i32> %tmp4, zeroinitializer
+  %cond = select <2 x i1> %0, <2 x i32> %tmp6, <2 x i32> %tmp8
+  store <2 x i32> %cond, <2 x i32> addrspace(1)* %def_c
+  ret void
+}