diff --git a/lib/Transforms/IPO/SimplifyLibCalls.cpp b/lib/Transforms/IPO/SimplifyLibCalls.cpp index c7038ccd798..0f19740e2de 100644 --- a/lib/Transforms/IPO/SimplifyLibCalls.cpp +++ b/lib/Transforms/IPO/SimplifyLibCalls.cpp @@ -688,82 +688,72 @@ public: "Number of 'strncmp' calls simplified") {} /// @brief Make sure that the "strncmp" function has the right prototype - virtual bool ValidateCalledFunction(const Function* f, SimplifyLibCalls& SLC){ - if (f->getReturnType() == Type::Int32Ty && f->arg_size() == 3) - return true; + virtual bool ValidateCalledFunction(const Function *F, SimplifyLibCalls &SLC){ + const FunctionType *FT = F->getFunctionType(); + return FT->getReturnType() == Type::Int32Ty && FT->getNumParams() == 3 && + FT->getParamType(0) == FT->getParamType(1) && + FT->getParamType(0) == PointerType::get(Type::Int8Ty) && + isa(FT->getParamType(2)); return false; } - /// @brief Perform the strncpy optimization - virtual bool OptimizeCall(CallInst *ci, SimplifyLibCalls &SLC) { + /// @brief Perform the strncmp optimization + virtual bool OptimizeCall(CallInst *CI, SimplifyLibCalls &SLC) { // First, check to see if src and destination are the same. If they are, // then the optimization is to replace the CallInst with a constant 0 // because the call is a no-op. - Value* s1 = ci->getOperand(1); - Value* s2 = ci->getOperand(2); - if (s1 == s2) { - // strncmp(x,x,l) -> 0 - ci->replaceAllUsesWith(ConstantInt::get(Type::Int32Ty,0)); - ci->eraseFromParent(); + Value *Str1P = CI->getOperand(1); + Value *Str2P = CI->getOperand(2); + if (Str1P == Str2P) { + // strcmp(x,x) -> 0 + CI->replaceAllUsesWith(ConstantInt::get(CI->getType(), 0)); + CI->eraseFromParent(); return true; } - + // Check the length argument, if it is Constant zero then the strings are // considered equal. - uint64_t len_arg = 0; - bool len_arg_is_const = false; - if (ConstantInt* len_CI = dyn_cast(ci->getOperand(3))) { - len_arg_is_const = true; - len_arg = len_CI->getZExtValue(); - if (len_arg == 0) { - // strncmp(x,y,0) -> 0 - ci->replaceAllUsesWith(ConstantInt::get(Type::Int32Ty,0)); - ci->eraseFromParent(); - return true; - } + ConstantInt *LengthArg = dyn_cast(CI->getOperand(3)); + if (LengthArg && LengthArg->isZero()) { + // strncmp(x,y,0) -> 0 + CI->replaceAllUsesWith(ConstantInt::get(CI->getType(), 0)); + CI->eraseFromParent(); + return true; } - - bool isstr_1 = false; - uint64_t len_1 = 0, StartIdx; - ConstantArray* A1; - if (GetConstantStringInfo(s1, A1, len_1, StartIdx)) { - isstr_1 = true; - if (len_1 == 0) { - // strncmp("",x) -> *x - LoadInst* load = new LoadInst(s1,ci->getName()+".load",ci); - CastInst* cast = - CastInst::create(Instruction::SExt, load, Type::Int32Ty, - ci->getName()+".int", ci); - ci->replaceAllUsesWith(cast); - ci->eraseFromParent(); - return true; - } + + uint64_t Str1Len, Str1StartIdx; + ConstantArray *A1; + bool Str1IsCst = GetConstantStringInfo(Str1P, A1, Str1Len, Str1StartIdx); + if (Str1IsCst && Str1Len == 0) { + // strcmp("", x) -> *x + Value *V = new LoadInst(Str2P, CI->getName()+".load", CI); + V = new ZExtInst(V, CI->getType(), CI->getName()+".int", CI); + CI->replaceAllUsesWith(V); + CI->eraseFromParent(); + return true; } - - bool isstr_2 = false; - uint64_t len_2 = 0; + + uint64_t Str2Len, Str2StartIdx; ConstantArray* A2; - if (GetConstantStringInfo(s2, A2, len_2, StartIdx)) { - isstr_2 = true; - if (len_2 == 0) { - // strncmp(x,"") -> *x - LoadInst* load = new LoadInst(s2,ci->getName()+".val",ci); - CastInst* cast = - CastInst::create(Instruction::SExt, load, Type::Int32Ty, - ci->getName()+".int", ci); - ci->replaceAllUsesWith(cast); - ci->eraseFromParent(); - return true; - } + bool Str2IsCst = GetConstantStringInfo(Str2P, A2, Str2Len, Str2StartIdx); + if (Str2IsCst && Str2Len == 0) { + // strcmp(x,"") -> *x + Value *V = new LoadInst(Str1P, CI->getName()+".load", CI); + V = new ZExtInst(V, CI->getType(), CI->getName()+".int", CI); + CI->replaceAllUsesWith(V); + CI->eraseFromParent(); + return true; } - - if (isstr_1 && isstr_2 && len_arg_is_const) { - // strncmp(x,y,const) -> constant - std::string str1 = A1->getAsString(); - std::string str2 = A2->getAsString(); - int result = strncmp(str1.c_str(), str2.c_str(), len_arg); - ci->replaceAllUsesWith(ConstantInt::get(Type::Int32Ty,result)); - ci->eraseFromParent(); + + if (LengthArg && Str1IsCst && Str2IsCst && A1->isCString() && + A2->isCString()) { + // strcmp(x, y) -> cnst (if both x and y are constant strings) + std::string S1 = A1->getAsString(); + std::string S2 = A2->getAsString(); + int R = strncmp(S1.c_str()+Str1StartIdx, S2.c_str()+Str2StartIdx, + LengthArg->getZExtValue()); + CI->replaceAllUsesWith(ConstantInt::get(CI->getType(), R)); + CI->eraseFromParent(); return true; } return false;