mirror of
				https://github.com/c64scene-ar/llvm-6502.git
				synced 2025-10-31 08:16:47 +00:00 
			
		
		
		
	git-svn-id: https://llvm.org/svn/llvm-project/llvm/trunk@221710 91177308-0d34-0410-b5e6-96231b3b80d8
		
			
				
	
	
		
			375 lines
		
	
	
		
			14 KiB
		
	
	
	
		
			C++
		
	
	
	
	
	
			
		
		
	
	
			375 lines
		
	
	
		
			14 KiB
		
	
	
	
		
			C++
		
	
	
	
	
	
| //===-- ForwardControlFlowIntegrity.cpp: Forward-Edge CFI -----------------===//
 | |
| //
 | |
| // This file is distributed under the University of Illinois Open Source
 | |
| // License. See LICENSE.TXT for details.
 | |
| //
 | |
| //===----------------------------------------------------------------------===//
 | |
| ///
 | |
| /// \file
 | |
| /// \brief A pass that instruments code with fast checks for indirect calls and
 | |
| /// hooks for a function to check violations.
 | |
| ///
 | |
| //===----------------------------------------------------------------------===//
 | |
| 
 | |
| #define DEBUG_TYPE "cfi"
 | |
| 
 | |
| #include "llvm/ADT/SmallVector.h"
 | |
| #include "llvm/ADT/Statistic.h"
 | |
| #include "llvm/Analysis/JumpInstrTableInfo.h"
 | |
| #include "llvm/CodeGen/ForwardControlFlowIntegrity.h"
 | |
| #include "llvm/CodeGen/JumpInstrTables.h"
 | |
| #include "llvm/CodeGen/Passes.h"
 | |
| #include "llvm/IR/Attributes.h"
 | |
| #include "llvm/IR/CallSite.h"
 | |
| #include "llvm/IR/Constants.h"
 | |
| #include "llvm/IR/DerivedTypes.h"
 | |
| #include "llvm/IR/Function.h"
 | |
| #include "llvm/IR/GlobalValue.h"
 | |
| #include "llvm/IR/Instructions.h"
 | |
| #include "llvm/IR/InlineAsm.h"
 | |
| #include "llvm/IR/IRBuilder.h"
 | |
| #include "llvm/IR/LLVMContext.h"
 | |
| #include "llvm/IR/Module.h"
 | |
| #include "llvm/IR/Operator.h"
 | |
| #include "llvm/IR/Type.h"
 | |
| #include "llvm/IR/Verifier.h"
 | |
| #include "llvm/Pass.h"
 | |
| #include "llvm/Support/CommandLine.h"
 | |
| #include "llvm/Support/Debug.h"
 | |
| #include "llvm/Support/raw_ostream.h"
 | |
| 
 | |
| using namespace llvm;
 | |
| 
 | |
| STATISTIC(NumCFIIndirectCalls,
 | |
|           "Number of indirect call sites rewritten by the CFI pass");
 | |
| 
 | |
| char ForwardControlFlowIntegrity::ID = 0;
 | |
| INITIALIZE_PASS_BEGIN(ForwardControlFlowIntegrity, "forward-cfi",
 | |
|                       "Control-Flow Integrity", true, true)
 | |
| INITIALIZE_PASS_DEPENDENCY(JumpInstrTableInfo);
 | |
| INITIALIZE_PASS_DEPENDENCY(JumpInstrTables);
 | |
| INITIALIZE_PASS_END(ForwardControlFlowIntegrity, "forward-cfi",
 | |
|                     "Control-Flow Integrity", true, true)
 | |
| 
 | |
| ModulePass *llvm::createForwardControlFlowIntegrityPass() {
 | |
|   return new ForwardControlFlowIntegrity();
 | |
| }
 | |
| 
 | |
| ModulePass *llvm::createForwardControlFlowIntegrityPass(
 | |
|     JumpTable::JumpTableType JTT, CFIntegrity CFIType, bool CFIEnforcing,
 | |
|     StringRef CFIFuncName) {
 | |
|   return new ForwardControlFlowIntegrity(JTT, CFIType, CFIEnforcing,
 | |
|                                          CFIFuncName);
 | |
| }
 | |
| 
 | |
| // Checks to see if a given CallSite is making an indirect call, including
 | |
| // cases where the indirect call is made through a bitcast.
 | |
| static bool isIndirectCall(CallSite &CS) {
 | |
|   if (CS.getCalledFunction())
 | |
|     return false;
 | |
| 
 | |
|   // Check the value to see if it is merely a bitcast of a function. In
 | |
|   // this case, it will translate to a direct function call in the resulting
 | |
|   // assembly, so we won't treat it as an indirect call here.
 | |
|   const Value *V = CS.getCalledValue();
 | |
|   if (const ConstantExpr *CE = dyn_cast<ConstantExpr>(V)) {
 | |
|     return !(CE->isCast() && isa<Function>(CE->getOperand(0)));
 | |
|   }
 | |
| 
 | |
|   // Otherwise, since we know it's a call, it must be an indirect call
 | |
|   return true;
 | |
| }
 | |
| 
 | |
| static const char cfi_failure_func_name[] = "__llvm_cfi_pointer_warning";
 | |
| 
 | |
| ForwardControlFlowIntegrity::ForwardControlFlowIntegrity()
 | |
|     : ModulePass(ID), IndirectCalls(), JTType(JumpTable::Single),
 | |
|       CFIType(CFIntegrity::Sub), CFIEnforcing(false), CFIFuncName("") {
 | |
|   initializeForwardControlFlowIntegrityPass(*PassRegistry::getPassRegistry());
 | |
| }
 | |
| 
 | |
| ForwardControlFlowIntegrity::ForwardControlFlowIntegrity(
 | |
|     JumpTable::JumpTableType JTT, CFIntegrity CFIType, bool CFIEnforcing,
 | |
|     std::string CFIFuncName)
 | |
|     : ModulePass(ID), IndirectCalls(), JTType(JTT), CFIType(CFIType),
 | |
|       CFIEnforcing(CFIEnforcing), CFIFuncName(CFIFuncName) {
 | |
|   initializeForwardControlFlowIntegrityPass(*PassRegistry::getPassRegistry());
 | |
| }
 | |
| 
 | |
| ForwardControlFlowIntegrity::~ForwardControlFlowIntegrity() {}
 | |
| 
 | |
| void ForwardControlFlowIntegrity::getAnalysisUsage(AnalysisUsage &AU) const {
 | |
|   AU.addRequired<JumpInstrTableInfo>();
 | |
|   AU.addRequired<JumpInstrTables>();
 | |
| }
 | |
| 
 | |
| void ForwardControlFlowIntegrity::getIndirectCalls(Module &M) {
 | |
|   // To get the indirect calls, we iterate over all functions and iterate over
 | |
|   // the list of basic blocks in each. We extract a total list of indirect calls
 | |
|   // before modifying any of them, since our modifications will modify the list
 | |
|   // of basic blocks.
 | |
|   for (Function &F : M) {
 | |
|     for (BasicBlock &BB : F) {
 | |
|       for (Instruction &I : BB) {
 | |
|         CallSite CS(&I);
 | |
|         if (!(CS && isIndirectCall(CS)))
 | |
|           continue;
 | |
| 
 | |
|         Value *CalledValue = CS.getCalledValue();
 | |
| 
 | |
|         // Don't rewrite this instruction if the indirect call is actually just
 | |
|         // inline assembly, since our transformation will generate an invalid
 | |
|         // module in that case.
 | |
|         if (isa<InlineAsm>(CalledValue))
 | |
|           continue;
 | |
| 
 | |
|         IndirectCalls.push_back(&I);
 | |
|       }
 | |
|     }
 | |
|   }
 | |
| }
 | |
| 
 | |
| void ForwardControlFlowIntegrity::updateIndirectCalls(Module &M,
 | |
|                                                       CFITables &CFIT) {
 | |
|   Type *Int64Ty = Type::getInt64Ty(M.getContext());
 | |
|   for (Instruction *I : IndirectCalls) {
 | |
|     CallSite CS(I);
 | |
|     Value *CalledValue = CS.getCalledValue();
 | |
| 
 | |
|     // Get the function type for this call and look it up in the tables.
 | |
|     Type *VTy = CalledValue->getType();
 | |
|     PointerType *PTy = dyn_cast<PointerType>(VTy);
 | |
|     Type *EltTy = PTy->getElementType();
 | |
|     FunctionType *FunTy = dyn_cast<FunctionType>(EltTy);
 | |
|     FunctionType *TransformedTy = JumpInstrTables::transformType(JTType, FunTy);
 | |
|     ++NumCFIIndirectCalls;
 | |
|     Constant *JumpTableStart = nullptr;
 | |
|     Constant *JumpTableMask = nullptr;
 | |
|     Constant *JumpTableSize = nullptr;
 | |
| 
 | |
|     // Some call sites have function types that don't correspond to any
 | |
|     // address-taken function in the module. This happens when function pointers
 | |
|     // are passed in from external code.
 | |
|     auto it = CFIT.find(TransformedTy);
 | |
|     if (it == CFIT.end()) {
 | |
|       // In this case, make sure that the function pointer will change by
 | |
|       // setting the mask and the start to be 0 so that the transformed
 | |
|       // function is 0.
 | |
|       JumpTableStart = ConstantInt::get(Int64Ty, 0);
 | |
|       JumpTableMask = ConstantInt::get(Int64Ty, 0);
 | |
|       JumpTableSize = ConstantInt::get(Int64Ty, 0);
 | |
|     } else {
 | |
|       JumpTableStart = it->second.StartValue;
 | |
|       JumpTableMask = it->second.MaskValue;
 | |
|       JumpTableSize = it->second.Size;
 | |
|     }
 | |
| 
 | |
|     rewriteFunctionPointer(M, I, CalledValue, JumpTableStart, JumpTableMask,
 | |
|                            JumpTableSize);
 | |
|   }
 | |
| 
 | |
|   return;
 | |
| }
 | |
| 
 | |
| bool ForwardControlFlowIntegrity::runOnModule(Module &M) {
 | |
|   JumpInstrTableInfo *JITI = &getAnalysis<JumpInstrTableInfo>();
 | |
|   Type *Int64Ty = Type::getInt64Ty(M.getContext());
 | |
|   Type *VoidPtrTy = Type::getInt8PtrTy(M.getContext());
 | |
| 
 | |
|   // JumpInstrTableInfo stores information about the alignment of each entry.
 | |
|   // The alignment returned by JumpInstrTableInfo is alignment in bytes, not
 | |
|   // in the exponent.
 | |
|   ByteAlignment = JITI->entryByteAlignment();
 | |
|   LogByteAlignment = llvm::Log2_64(ByteAlignment);
 | |
| 
 | |
|   // Set up tables for control-flow integrity based on information about the
 | |
|   // jump-instruction tables.
 | |
|   CFITables CFIT;
 | |
|   for (const auto &KV : JITI->getTables()) {
 | |
|     uint64_t Size = static_cast<uint64_t>(KV.second.size());
 | |
|     uint64_t TableSize = NextPowerOf2(Size);
 | |
| 
 | |
|     int64_t MaskValue = ((TableSize << LogByteAlignment) - 1) & -ByteAlignment;
 | |
|     Constant *JumpTableMaskValue = ConstantInt::get(Int64Ty, MaskValue);
 | |
|     Constant *JumpTableSize = ConstantInt::get(Int64Ty, Size);
 | |
| 
 | |
|     // The base of the table is defined to be the first jumptable function in
 | |
|     // the table.
 | |
|     Function *First = KV.second.begin()->second;
 | |
|     Constant *JumpTableStartValue = ConstantExpr::getBitCast(First, VoidPtrTy);
 | |
|     CFIT[KV.first].StartValue = JumpTableStartValue;
 | |
|     CFIT[KV.first].MaskValue = JumpTableMaskValue;
 | |
|     CFIT[KV.first].Size = JumpTableSize;
 | |
|   }
 | |
| 
 | |
|   if (CFIT.empty())
 | |
|     return false;
 | |
| 
 | |
|   getIndirectCalls(M);
 | |
| 
 | |
|   if (!CFIEnforcing) {
 | |
|     addWarningFunction(M);
 | |
|   }
 | |
| 
 | |
|   // Update the instructions with the check and the indirect jump through our
 | |
|   // table.
 | |
|   updateIndirectCalls(M, CFIT);
 | |
| 
 | |
|   return true;
 | |
| }
 | |
| 
 | |
| void ForwardControlFlowIntegrity::addWarningFunction(Module &M) {
 | |
|   PointerType *CharPtrTy = Type::getInt8PtrTy(M.getContext());
 | |
| 
 | |
|   // Get the type of the Warning Function: void (i8*, i8*),
 | |
|   // where the first argument is the name of the function in which the violation
 | |
|   // occurs, and the second is the function pointer that violates CFI.
 | |
|   SmallVector<Type *, 2> WarningFunArgs;
 | |
|   WarningFunArgs.push_back(CharPtrTy);
 | |
|   WarningFunArgs.push_back(CharPtrTy);
 | |
|   FunctionType *WarningFunTy =
 | |
|       FunctionType::get(Type::getVoidTy(M.getContext()), WarningFunArgs, false);
 | |
| 
 | |
|   if (!CFIFuncName.empty()) {
 | |
|     Constant *FailureFun = M.getOrInsertFunction(CFIFuncName, WarningFunTy);
 | |
|     if (!FailureFun)
 | |
|       report_fatal_error("Could not get or insert the function specified by"
 | |
|                          " -cfi-func-name");
 | |
|   } else {
 | |
|     // The default warning function swallows the warning and lets the call
 | |
|     // continue, since there's no generic way for it to print out this
 | |
|     // information.
 | |
|     Function *WarningFun = M.getFunction(cfi_failure_func_name);
 | |
|     if (!WarningFun) {
 | |
|       WarningFun =
 | |
|           Function::Create(WarningFunTy, GlobalValue::LinkOnceAnyLinkage,
 | |
|                            cfi_failure_func_name, &M);
 | |
|     }
 | |
| 
 | |
|     BasicBlock *Entry =
 | |
|         BasicBlock::Create(M.getContext(), "entry", WarningFun, 0);
 | |
|     ReturnInst::Create(M.getContext(), Entry);
 | |
|   }
 | |
| }
 | |
| 
 | |
| void ForwardControlFlowIntegrity::rewriteFunctionPointer(
 | |
|     Module &M, Instruction *I, Value *FunPtr, Constant *JumpTableStart,
 | |
|     Constant *JumpTableMask, Constant *JumpTableSize) {
 | |
|   IRBuilder<> TempBuilder(I);
 | |
| 
 | |
|   Type *OrigFunType = FunPtr->getType();
 | |
| 
 | |
|   BasicBlock *CurBB = cast<BasicBlock>(I->getParent());
 | |
|   Function *CurF = cast<Function>(CurBB->getParent());
 | |
|   Type *Int64Ty = Type::getInt64Ty(M.getContext());
 | |
| 
 | |
|   Value *TI = TempBuilder.CreatePtrToInt(FunPtr, Int64Ty);
 | |
|   Value *TStartInt = TempBuilder.CreatePtrToInt(JumpTableStart, Int64Ty);
 | |
| 
 | |
|   Value *NewFunPtr = nullptr;
 | |
|   Value *Check = nullptr;
 | |
|   switch (CFIType) {
 | |
|   case CFIntegrity::Sub: {
 | |
|     // This is the subtract, mask, and add version.
 | |
|     // Subtract from the base.
 | |
|     Value *Sub = TempBuilder.CreateSub(TI, TStartInt);
 | |
| 
 | |
|     // Mask the difference to force this to be a table offset.
 | |
|     Value *And = TempBuilder.CreateAnd(Sub, JumpTableMask);
 | |
| 
 | |
|     // Add it back to the base.
 | |
|     Value *Result = TempBuilder.CreateAdd(And, TStartInt);
 | |
| 
 | |
|     // Convert it back into a function pointer that we can call.
 | |
|     NewFunPtr = TempBuilder.CreateIntToPtr(Result, OrigFunType);
 | |
|     break;
 | |
|   }
 | |
|   case CFIntegrity::Ror: {
 | |
|     // This is the subtract and rotate version.
 | |
|     // Rotate right by the alignment value. The optimizer should recognize
 | |
|     // this sequence as a rotation.
 | |
| 
 | |
|     // This cast is safe, since unsigned is always a subset of uint64_t.
 | |
|     uint64_t LogByteAlignment64 = static_cast<uint64_t>(LogByteAlignment);
 | |
|     Constant *RightShift = ConstantInt::get(Int64Ty, LogByteAlignment64);
 | |
|     Constant *LeftShift = ConstantInt::get(Int64Ty, 64 - LogByteAlignment64);
 | |
| 
 | |
|     // Subtract from the base.
 | |
|     Value *Sub = TempBuilder.CreateSub(TI, TStartInt);
 | |
| 
 | |
|     // Create the equivalent of a rotate-right instruction.
 | |
|     Value *Shr = TempBuilder.CreateLShr(Sub, RightShift);
 | |
|     Value *Shl = TempBuilder.CreateShl(Sub, LeftShift);
 | |
|     Value *Or = TempBuilder.CreateOr(Shr, Shl);
 | |
| 
 | |
|     // Perform unsigned comparison to check for inclusion in the table.
 | |
|     Check = TempBuilder.CreateICmpULT(Or, JumpTableSize);
 | |
|     NewFunPtr = FunPtr;
 | |
|     break;
 | |
|   }
 | |
|   case CFIntegrity::Add: {
 | |
|     // This is the mask and add version.
 | |
|     // Mask the function pointer to turn it into an offset into the table.
 | |
|     Value *And = TempBuilder.CreateAnd(TI, JumpTableMask);
 | |
| 
 | |
|     // Then or this offset to the base and get the pointer value.
 | |
|     Value *Result = TempBuilder.CreateAdd(And, TStartInt);
 | |
| 
 | |
|     // Convert it back into a function pointer that we can call.
 | |
|     NewFunPtr = TempBuilder.CreateIntToPtr(Result, OrigFunType);
 | |
|     break;
 | |
|   }
 | |
|   }
 | |
| 
 | |
|   if (!CFIEnforcing) {
 | |
|     // If a check hasn't been added (in the rotation version), then check to see
 | |
|     // if it's the same as the original function. This check determines whether
 | |
|     // or not we call the CFI failure function.
 | |
|     if (!Check)
 | |
|       Check = TempBuilder.CreateICmpEQ(NewFunPtr, FunPtr);
 | |
|     BasicBlock *InvalidPtrBlock =
 | |
|         BasicBlock::Create(M.getContext(), "invalid.ptr", CurF, 0);
 | |
|     BasicBlock *ContinuationBB = CurBB->splitBasicBlock(I);
 | |
| 
 | |
|     // Remove the unconditional branch that connects the two blocks.
 | |
|     TerminatorInst *TermInst = CurBB->getTerminator();
 | |
|     TermInst->eraseFromParent();
 | |
| 
 | |
|     // Add a conditional branch that depends on the Check above.
 | |
|     BranchInst::Create(ContinuationBB, InvalidPtrBlock, Check, CurBB);
 | |
| 
 | |
|     // Call the warning function for this pointer, then continue.
 | |
|     Instruction *BI = BranchInst::Create(ContinuationBB, InvalidPtrBlock);
 | |
|     insertWarning(M, InvalidPtrBlock, BI, FunPtr);
 | |
|   } else {
 | |
|     // Modify the instruction to call this value.
 | |
|     CallSite CS(I);
 | |
|     CS.setCalledFunction(NewFunPtr);
 | |
|   }
 | |
| }
 | |
| 
 | |
| void ForwardControlFlowIntegrity::insertWarning(Module &M, BasicBlock *Block,
 | |
|                                                 Instruction *I, Value *FunPtr) {
 | |
|   Function *ParentFun = cast<Function>(Block->getParent());
 | |
| 
 | |
|   // Get the function to call right before the instruction.
 | |
|   Function *WarningFun = nullptr;
 | |
|   if (CFIFuncName.empty()) {
 | |
|     WarningFun = M.getFunction(cfi_failure_func_name);
 | |
|   } else {
 | |
|     WarningFun = M.getFunction(CFIFuncName);
 | |
|   }
 | |
| 
 | |
|   assert(WarningFun && "Could not find the CFI failure function");
 | |
| 
 | |
|   Type *VoidPtrTy = Type::getInt8PtrTy(M.getContext());
 | |
| 
 | |
|   IRBuilder<> WarningInserter(I);
 | |
|   // Create a mergeable GlobalVariable containing the name of the function.
 | |
|   Value *ParentNameGV =
 | |
|       WarningInserter.CreateGlobalString(ParentFun->getName());
 | |
|   Value *ParentNamePtr = WarningInserter.CreateBitCast(ParentNameGV, VoidPtrTy);
 | |
|   Value *FunVoidPtr = WarningInserter.CreateBitCast(FunPtr, VoidPtrTy);
 | |
|   WarningInserter.CreateCall2(WarningFun, ParentNamePtr, FunVoidPtr);
 | |
| }
 |