diff --git a/lib/VMCore/Instructions.cpp b/lib/VMCore/Instructions.cpp index af1492d356a..91f0febf349 100644 --- a/lib/VMCore/Instructions.cpp +++ b/lib/VMCore/Instructions.cpp @@ -2076,6 +2076,7 @@ unsigned CastInst::isEliminableCastPair( CastInst *CastInst::Create(Instruction::CastOps op, Value *S, const Type *Ty, const Twine &Name, Instruction *InsertBefore) { + assert(castIsValid(op, S, Ty) && "Invalid cast!"); // Construct and return the appropriate CastInst subclass switch (op) { case Trunc: return new TruncInst (S, Ty, Name, InsertBefore); @@ -2098,6 +2099,7 @@ CastInst *CastInst::Create(Instruction::CastOps op, Value *S, const Type *Ty, CastInst *CastInst::Create(Instruction::CastOps op, Value *S, const Type *Ty, const Twine &Name, BasicBlock *InsertAtEnd) { + assert(castIsValid(op, S, Ty) && "Invalid cast!"); // Construct and return the appropriate CastInst subclass switch (op) { case Trunc: return new TruncInst (S, Ty, Name, InsertAtEnd); @@ -2263,8 +2265,8 @@ bool CastInst::isCastable(const Type *SrcTy, const Type *DestTy) { } // Get the bit sizes, we'll need these - unsigned SrcBits = SrcTy->getScalarSizeInBits(); // 0 for ptr - unsigned DestBits = DestTy->getScalarSizeInBits(); // 0 for ptr + unsigned SrcBits = SrcTy->getPrimitiveSizeInBits(); // 0 for ptr + unsigned DestBits = DestTy->getPrimitiveSizeInBits(); // 0 for ptr // Run through the possibilities ... if (DestTy->isIntegerTy()) { // Casting to integral @@ -2348,8 +2350,8 @@ CastInst::getCastOpcode( } // Get the bit sizes, we'll need these - unsigned SrcBits = SrcTy->getScalarSizeInBits(); // 0 for ptr - unsigned DestBits = DestTy->getScalarSizeInBits(); // 0 for ptr + unsigned SrcBits = SrcTy->getPrimitiveSizeInBits(); // 0 for ptr + unsigned DestBits = DestTy->getPrimitiveSizeInBits(); // 0 for ptr // Run through the possibilities ... if (DestTy->isIntegerTy()) { // Casting to integral @@ -2463,46 +2465,40 @@ CastInst::castIsValid(Instruction::CastOps op, Value *S, const Type *DstTy) { unsigned SrcBitSize = SrcTy->getScalarSizeInBits(); unsigned DstBitSize = DstTy->getScalarSizeInBits(); + // If these are vector types, get the lengths of the vectors (using zero for + // scalar types means that checking that vector lengths match also checks that + // scalars are not being converted to vectors or vectors to scalars). + unsigned SrcLength = SrcTy->isVectorTy() ? + cast(SrcTy)->getNumElements() : 0; + unsigned DstLength = DstTy->isVectorTy() ? + cast(DstTy)->getNumElements() : 0; + // Switch on the opcode provided switch (op) { default: return false; // This is an input error case Instruction::Trunc: - return SrcTy->isIntOrIntVectorTy() && - DstTy->isIntOrIntVectorTy()&& SrcBitSize > DstBitSize; + return SrcTy->isIntOrIntVectorTy() && DstTy->isIntOrIntVectorTy() && + SrcLength == DstLength && SrcBitSize > DstBitSize; case Instruction::ZExt: - return SrcTy->isIntOrIntVectorTy() && - DstTy->isIntOrIntVectorTy()&& SrcBitSize < DstBitSize; + return SrcTy->isIntOrIntVectorTy() && DstTy->isIntOrIntVectorTy() && + SrcLength == DstLength && SrcBitSize < DstBitSize; case Instruction::SExt: - return SrcTy->isIntOrIntVectorTy() && - DstTy->isIntOrIntVectorTy()&& SrcBitSize < DstBitSize; + return SrcTy->isIntOrIntVectorTy() && DstTy->isIntOrIntVectorTy() && + SrcLength == DstLength && SrcBitSize < DstBitSize; case Instruction::FPTrunc: - return SrcTy->isFPOrFPVectorTy() && - DstTy->isFPOrFPVectorTy() && - SrcBitSize > DstBitSize; + return SrcTy->isFPOrFPVectorTy() && DstTy->isFPOrFPVectorTy() && + SrcLength == DstLength && SrcBitSize > DstBitSize; case Instruction::FPExt: - return SrcTy->isFPOrFPVectorTy() && - DstTy->isFPOrFPVectorTy() && - SrcBitSize < DstBitSize; + return SrcTy->isFPOrFPVectorTy() && DstTy->isFPOrFPVectorTy() && + SrcLength == DstLength && SrcBitSize < DstBitSize; case Instruction::UIToFP: case Instruction::SIToFP: - if (const VectorType *SVTy = dyn_cast(SrcTy)) { - if (const VectorType *DVTy = dyn_cast(DstTy)) { - return SVTy->getElementType()->isIntOrIntVectorTy() && - DVTy->getElementType()->isFPOrFPVectorTy() && - SVTy->getNumElements() == DVTy->getNumElements(); - } - } - return SrcTy->isIntOrIntVectorTy() && DstTy->isFPOrFPVectorTy(); + return SrcTy->isIntOrIntVectorTy() && DstTy->isFPOrFPVectorTy() && + SrcLength == DstLength; case Instruction::FPToUI: case Instruction::FPToSI: - if (const VectorType *SVTy = dyn_cast(SrcTy)) { - if (const VectorType *DVTy = dyn_cast(DstTy)) { - return SVTy->getElementType()->isFPOrFPVectorTy() && - DVTy->getElementType()->isIntOrIntVectorTy() && - SVTy->getNumElements() == DVTy->getNumElements(); - } - } - return SrcTy->isFPOrFPVectorTy() && DstTy->isIntOrIntVectorTy(); + return SrcTy->isFPOrFPVectorTy() && DstTy->isIntOrIntVectorTy() && + SrcLength == DstLength; case Instruction::PtrToInt: return SrcTy->isPointerTy() && DstTy->isIntegerTy(); case Instruction::IntToPtr: diff --git a/test/Assembler/invalid_cast.ll b/test/Assembler/invalid_cast.ll new file mode 100644 index 00000000000..c5b082b6b8d --- /dev/null +++ b/test/Assembler/invalid_cast.ll @@ -0,0 +1,6 @@ +; RUN: not llvm-as < %s |& grep {invalid cast opcode} + +define <3 x i8> @foo(<4 x i64> %x) { + %y = trunc <4 x i64> %x to <3 x i8> + ret <3 x i8> %y +} diff --git a/test/Assembler/invalid_cast2.ll b/test/Assembler/invalid_cast2.ll new file mode 100644 index 00000000000..f2e7c414e71 --- /dev/null +++ b/test/Assembler/invalid_cast2.ll @@ -0,0 +1,6 @@ +; RUN: not llvm-as < %s |& grep {invalid cast opcode} + +define i8 @foo(<4 x i64> %x) { + %y = trunc <4 x i64> %x to i8 + ret i8 %y +}