diff --git a/include/llvm/IR/Operator.h b/include/llvm/IR/Operator.h index 46935ce3e96..c87f89c6157 100644 --- a/include/llvm/IR/Operator.h +++ b/include/llvm/IR/Operator.h @@ -400,6 +400,11 @@ public: return getPointerOperand()->getType(); } + Type *getSourceElementType() const { + return cast(getPointerOperandType()->getScalarType()) + ->getElementType(); + } + /// Method to return the address space of the pointer operand. unsigned getPointerAddressSpace() const { return getPointerOperandType()->getPointerAddressSpace(); diff --git a/lib/Bitcode/Reader/BitcodeReader.cpp b/lib/Bitcode/Reader/BitcodeReader.cpp index 33b02f912f0..9a0ec19e6f3 100644 --- a/lib/Bitcode/Reader/BitcodeReader.cpp +++ b/lib/Bitcode/Reader/BitcodeReader.cpp @@ -1955,19 +1955,25 @@ std::error_code BitcodeReader::ParseConstants() { } case bitc::CST_CODE_CE_INBOUNDS_GEP: case bitc::CST_CODE_CE_GEP: { // CE_GEP: [n x operands] - if (Record.size() & 1) - return Error("Invalid record"); + unsigned OpNum = 0; + Type *PointeeType = nullptr; + if (Record.size() % 2) + PointeeType = getTypeByID(Record[OpNum++]); SmallVector Elts; - for (unsigned i = 0, e = Record.size(); i != e; i += 2) { - Type *ElTy = getTypeByID(Record[i]); + while (OpNum != Record.size()) { + Type *ElTy = getTypeByID(Record[OpNum++]); if (!ElTy) return Error("Invalid record"); - Elts.push_back(ValueList.getConstantFwdRef(Record[i+1], ElTy)); + Elts.push_back(ValueList.getConstantFwdRef(Record[OpNum++], ElTy)); } + ArrayRef Indices(Elts.begin() + 1, Elts.end()); V = ConstantExpr::getGetElementPtr(Elts[0], Indices, BitCode == bitc::CST_CODE_CE_INBOUNDS_GEP); + if (PointeeType && + PointeeType != cast(V)->getSourceElementType()) + return Error("Invalid record"); break; } case bitc::CST_CODE_CE_SELECT: { // CE_SELECT: [opval#, opval#, opval#] diff --git a/lib/Bitcode/Writer/BitcodeWriter.cpp b/lib/Bitcode/Writer/BitcodeWriter.cpp index ecb6f7c130a..d2417acc689 100644 --- a/lib/Bitcode/Writer/BitcodeWriter.cpp +++ b/lib/Bitcode/Writer/BitcodeWriter.cpp @@ -1522,15 +1522,18 @@ static void WriteConstants(unsigned FirstVal, unsigned LastVal, Record.push_back(Flags); } break; - case Instruction::GetElementPtr: + case Instruction::GetElementPtr: { Code = bitc::CST_CODE_CE_GEP; - if (cast(C)->isInBounds()) + const auto *GO = cast(C); + if (GO->isInBounds()) Code = bitc::CST_CODE_CE_INBOUNDS_GEP; + Record.push_back(VE.getTypeID(GO->getSourceElementType())); for (unsigned i = 0, e = CE->getNumOperands(); i != e; ++i) { Record.push_back(VE.getTypeID(C->getOperand(i)->getType())); Record.push_back(VE.getValueID(C->getOperand(i))); } break; + } case Instruction::Select: Code = bitc::CST_CODE_CE_SELECT; Record.push_back(VE.getValueID(C->getOperand(0)));