From 8b0f0cb9088a7746fea2ba23821e50d87cef4a56 Mon Sep 17 00:00:00 2001 From: Chris Lattner Date: Mon, 12 Jan 2004 21:02:29 +0000 Subject: [PATCH] Remove a whole bunch more ugliness. This is actually getting to the point of this whole refactoring: allow constant folding methods to return something other than predefined classes, allow them to return generic Constant*'s. git-svn-id: https://llvm.org/svn/llvm-project/llvm/trunk@10806 91177308-0d34-0410-b5e6-96231b3b80d8 --- lib/VMCore/ConstantFold.cpp | 118 ++++++++++++++++++++--------------- lib/VMCore/ConstantFold.h | 60 ++++++------------ lib/VMCore/ConstantFolding.h | 60 ++++++------------ 3 files changed, 106 insertions(+), 132 deletions(-) diff --git a/lib/VMCore/ConstantFold.cpp b/lib/VMCore/ConstantFold.cpp index 45b021bb7dc..ddec284942a 100644 --- a/lib/VMCore/ConstantFold.cpp +++ b/lib/VMCore/ConstantFold.cpp @@ -12,6 +12,7 @@ //===----------------------------------------------------------------------===// #include "ConstantHandling.h" +#include "llvm/Constants.h" #include "llvm/iPHINode.h" #include "llvm/InstrTypes.h" #include "llvm/DerivedTypes.h" @@ -57,7 +58,24 @@ Constant *llvm::ConstantFoldCastInstruction(const Constant *V, return ConstantExpr::getCast(CE->getOperand(0), DestTy); } - return ConstRules::get(V, V).castTo(V, DestTy); + ConstRules &Rules = ConstRules::get(V, V); + + switch (DestTy->getPrimitiveID()) { + case Type::BoolTyID: return Rules.castToBool(V); + case Type::UByteTyID: return Rules.castToUByte(V); + case Type::SByteTyID: return Rules.castToSByte(V); + case Type::UShortTyID: return Rules.castToUShort(V); + case Type::ShortTyID: return Rules.castToShort(V); + case Type::UIntTyID: return Rules.castToUInt(V); + case Type::IntTyID: return Rules.castToInt(V); + case Type::ULongTyID: return Rules.castToULong(V); + case Type::LongTyID: return Rules.castToLong(V); + case Type::FloatTyID: return Rules.castToFloat(V); + case Type::DoubleTyID: return Rules.castToDouble(V); + case Type::PointerTyID: + return Rules.castToPointer(V, cast(DestTy)); + default: return 0; + } } Constant *llvm::ConstantFoldBinaryInstruction(unsigned Opcode, @@ -209,47 +227,45 @@ class TemplateRules : public ConstRules { return SubClassName::Shr((const ArgType *)V1, (const ArgType *)V2); } - virtual ConstantBool *lessthan(const Constant *V1, - const Constant *V2) const { + virtual Constant *lessthan(const Constant *V1, const Constant *V2) const { return SubClassName::LessThan((const ArgType *)V1, (const ArgType *)V2); } - virtual ConstantBool *equalto(const Constant *V1, - const Constant *V2) const { + virtual Constant *equalto(const Constant *V1, const Constant *V2) const { return SubClassName::EqualTo((const ArgType *)V1, (const ArgType *)V2); } // Casting operators. ick - virtual ConstantBool *castToBool(const Constant *V) const { + virtual Constant *castToBool(const Constant *V) const { return SubClassName::CastToBool((const ArgType*)V); } - virtual ConstantSInt *castToSByte(const Constant *V) const { + virtual Constant *castToSByte(const Constant *V) const { return SubClassName::CastToSByte((const ArgType*)V); } - virtual ConstantUInt *castToUByte(const Constant *V) const { + virtual Constant *castToUByte(const Constant *V) const { return SubClassName::CastToUByte((const ArgType*)V); } - virtual ConstantSInt *castToShort(const Constant *V) const { + virtual Constant *castToShort(const Constant *V) const { return SubClassName::CastToShort((const ArgType*)V); } - virtual ConstantUInt *castToUShort(const Constant *V) const { + virtual Constant *castToUShort(const Constant *V) const { return SubClassName::CastToUShort((const ArgType*)V); } - virtual ConstantSInt *castToInt(const Constant *V) const { + virtual Constant *castToInt(const Constant *V) const { return SubClassName::CastToInt((const ArgType*)V); } - virtual ConstantUInt *castToUInt(const Constant *V) const { + virtual Constant *castToUInt(const Constant *V) const { return SubClassName::CastToUInt((const ArgType*)V); } - virtual ConstantSInt *castToLong(const Constant *V) const { + virtual Constant *castToLong(const Constant *V) const { return SubClassName::CastToLong((const ArgType*)V); } - virtual ConstantUInt *castToULong(const Constant *V) const { + virtual Constant *castToULong(const Constant *V) const { return SubClassName::CastToULong((const ArgType*)V); } - virtual ConstantFP *castToFloat(const Constant *V) const { + virtual Constant *castToFloat(const Constant *V) const { return SubClassName::CastToFloat((const ArgType*)V); } - virtual ConstantFP *castToDouble(const Constant *V) const { + virtual Constant *castToDouble(const Constant *V) const { return SubClassName::CastToDouble((const ArgType*)V); } virtual Constant *castToPointer(const Constant *V, @@ -271,27 +287,27 @@ class TemplateRules : public ConstRules { static Constant *Xor(const ArgType *V1, const ArgType *V2) { return 0; } static Constant *Shl(const ArgType *V1, const ArgType *V2) { return 0; } static Constant *Shr(const ArgType *V1, const ArgType *V2) { return 0; } - static ConstantBool *LessThan(const ArgType *V1, const ArgType *V2) { + static Constant *LessThan(const ArgType *V1, const ArgType *V2) { return 0; } - static ConstantBool *EqualTo(const ArgType *V1, const ArgType *V2) { + static Constant *EqualTo(const ArgType *V1, const ArgType *V2) { return 0; } // Casting operators. ick - static ConstantBool *CastToBool (const Constant *V) { return 0; } - static ConstantSInt *CastToSByte (const Constant *V) { return 0; } - static ConstantUInt *CastToUByte (const Constant *V) { return 0; } - static ConstantSInt *CastToShort (const Constant *V) { return 0; } - static ConstantUInt *CastToUShort(const Constant *V) { return 0; } - static ConstantSInt *CastToInt (const Constant *V) { return 0; } - static ConstantUInt *CastToUInt (const Constant *V) { return 0; } - static ConstantSInt *CastToLong (const Constant *V) { return 0; } - static ConstantUInt *CastToULong (const Constant *V) { return 0; } - static ConstantFP *CastToFloat (const Constant *V) { return 0; } - static ConstantFP *CastToDouble(const Constant *V) { return 0; } - static Constant *CastToPointer(const Constant *, - const PointerType *) {return 0;} + static Constant *CastToBool (const Constant *V) { return 0; } + static Constant *CastToSByte (const Constant *V) { return 0; } + static Constant *CastToUByte (const Constant *V) { return 0; } + static Constant *CastToShort (const Constant *V) { return 0; } + static Constant *CastToUShort(const Constant *V) { return 0; } + static Constant *CastToInt (const Constant *V) { return 0; } + static Constant *CastToUInt (const Constant *V) { return 0; } + static Constant *CastToLong (const Constant *V) { return 0; } + static Constant *CastToULong (const Constant *V) { return 0; } + static Constant *CastToFloat (const Constant *V) { return 0; } + static Constant *CastToDouble(const Constant *V) { return 0; } + static Constant *CastToPointer(const Constant *, + const PointerType *) {return 0;} }; @@ -303,7 +319,7 @@ class TemplateRules : public ConstRules { // EmptyRules provides a concrete base class of ConstRules that does nothing // struct EmptyRules : public TemplateRules { - static ConstantBool *EqualTo(const Constant *V1, const Constant *V2) { + static Constant *EqualTo(const Constant *V1, const Constant *V2) { if (V1 == V2) return ConstantBool::True; return 0; } @@ -319,11 +335,11 @@ struct EmptyRules : public TemplateRules { // struct BoolRules : public TemplateRules { - static ConstantBool *LessThan(const ConstantBool *V1, const ConstantBool *V2){ + static Constant *LessThan(const ConstantBool *V1, const ConstantBool *V2){ return ConstantBool::get(V1->getValue() < V2->getValue()); } - static ConstantBool *EqualTo(const Constant *V1, const Constant *V2) { + static Constant *EqualTo(const Constant *V1, const Constant *V2) { return ConstantBool::get(V1 == V2); } @@ -341,7 +357,7 @@ struct BoolRules : public TemplateRules { // Casting operators. ick #define DEF_CAST(TYPE, CLASS, CTYPE) \ - static CLASS *CastTo##TYPE (const ConstantBool *V) { \ + static Constant *CastTo##TYPE (const ConstantBool *V) { \ return CLASS::get(Type::TYPE##Ty, (CTYPE)(bool)V->getValue()); \ } @@ -369,40 +385,40 @@ struct BoolRules : public TemplateRules { // struct NullPointerRules : public TemplateRules { - static ConstantBool *EqualTo(const Constant *V1, const Constant *V2) { + static Constant *EqualTo(const Constant *V1, const Constant *V2) { return ConstantBool::True; // Null pointers are always equal } - static ConstantBool *CastToBool (const Constant *V) { + static Constant *CastToBool(const Constant *V) { return ConstantBool::False; } - static ConstantSInt *CastToSByte (const Constant *V) { + static Constant *CastToSByte (const Constant *V) { return ConstantSInt::get(Type::SByteTy, 0); } - static ConstantUInt *CastToUByte (const Constant *V) { + static Constant *CastToUByte (const Constant *V) { return ConstantUInt::get(Type::UByteTy, 0); } - static ConstantSInt *CastToShort (const Constant *V) { + static Constant *CastToShort (const Constant *V) { return ConstantSInt::get(Type::ShortTy, 0); } - static ConstantUInt *CastToUShort(const Constant *V) { + static Constant *CastToUShort(const Constant *V) { return ConstantUInt::get(Type::UShortTy, 0); } - static ConstantSInt *CastToInt (const Constant *V) { + static Constant *CastToInt (const Constant *V) { return ConstantSInt::get(Type::IntTy, 0); } - static ConstantUInt *CastToUInt (const Constant *V) { + static Constant *CastToUInt (const Constant *V) { return ConstantUInt::get(Type::UIntTy, 0); } - static ConstantSInt *CastToLong (const Constant *V) { + static Constant *CastToLong (const Constant *V) { return ConstantSInt::get(Type::LongTy, 0); } - static ConstantUInt *CastToULong (const Constant *V) { + static Constant *CastToULong (const Constant *V) { return ConstantUInt::get(Type::ULongTy, 0); } - static ConstantFP *CastToFloat (const Constant *V) { + static Constant *CastToFloat (const Constant *V) { return ConstantFP::get(Type::FloatTy, 0); } - static ConstantFP *CastToDouble(const Constant *V) { + static Constant *CastToDouble(const Constant *V) { return ConstantFP::get(Type::DoubleTy, 0); } @@ -444,14 +460,12 @@ struct DirectRules : public TemplateRules { return ConstantClass::get(*Ty, R); } - static ConstantBool *LessThan(const ConstantClass *V1, - const ConstantClass *V2) { + static Constant *LessThan(const ConstantClass *V1, const ConstantClass *V2) { bool R = (BuiltinType)V1->getValue() < (BuiltinType)V2->getValue(); return ConstantBool::get(R); } - static ConstantBool *EqualTo(const ConstantClass *V1, - const ConstantClass *V2) { + static Constant *EqualTo(const ConstantClass *V1, const ConstantClass *V2) { bool R = (BuiltinType)V1->getValue() == (BuiltinType)V2->getValue(); return ConstantBool::get(R); } @@ -465,7 +479,7 @@ struct DirectRules : public TemplateRules { // Casting operators. ick #define DEF_CAST(TYPE, CLASS, CTYPE) \ - static CLASS *CastTo##TYPE (const ConstantClass *V) { \ + static Constant *CastTo##TYPE (const ConstantClass *V) { \ return CLASS::get(Type::TYPE##Ty, (CTYPE)(BuiltinType)V->getValue()); \ } diff --git a/lib/VMCore/ConstantFold.h b/lib/VMCore/ConstantFold.h index 8475e44f908..dc5d0cfbfae 100644 --- a/lib/VMCore/ConstantFold.h +++ b/lib/VMCore/ConstantFold.h @@ -15,12 +15,12 @@ #ifndef CONSTANTHANDLING_H #define CONSTANTHANDLING_H -#include "llvm/Constants.h" -#include "llvm/Type.h" +#include namespace llvm { - -class PointerType; + class Constant; + class Type; + class PointerType; struct ConstRules { ConstRules() {} @@ -37,44 +37,24 @@ struct ConstRules { virtual Constant *shl(const Constant *V1, const Constant *V2) const = 0; virtual Constant *shr(const Constant *V1, const Constant *V2) const = 0; - virtual ConstantBool *lessthan(const Constant *V1, - const Constant *V2) const = 0; - virtual ConstantBool *equalto(const Constant *V1, - const Constant *V2) const = 0; + virtual Constant *lessthan(const Constant *V1, const Constant *V2) const = 0; + + virtual Constant *equalto(const Constant *V1, const Constant *V2) const = 0; // Casting operators. ick - virtual ConstantBool *castToBool (const Constant *V) const = 0; - virtual ConstantSInt *castToSByte (const Constant *V) const = 0; - virtual ConstantUInt *castToUByte (const Constant *V) const = 0; - virtual ConstantSInt *castToShort (const Constant *V) const = 0; - virtual ConstantUInt *castToUShort(const Constant *V) const = 0; - virtual ConstantSInt *castToInt (const Constant *V) const = 0; - virtual ConstantUInt *castToUInt (const Constant *V) const = 0; - virtual ConstantSInt *castToLong (const Constant *V) const = 0; - virtual ConstantUInt *castToULong (const Constant *V) const = 0; - virtual ConstantFP *castToFloat (const Constant *V) const = 0; - virtual ConstantFP *castToDouble(const Constant *V) const = 0; - virtual Constant *castToPointer(const Constant *V, - const PointerType *Ty) const = 0; - - inline Constant *castTo(const Constant *V, const Type *Ty) const { - switch (Ty->getPrimitiveID()) { - case Type::BoolTyID: return castToBool(V); - case Type::UByteTyID: return castToUByte(V); - case Type::SByteTyID: return castToSByte(V); - case Type::UShortTyID: return castToUShort(V); - case Type::ShortTyID: return castToShort(V); - case Type::UIntTyID: return castToUInt(V); - case Type::IntTyID: return castToInt(V); - case Type::ULongTyID: return castToULong(V); - case Type::LongTyID: return castToLong(V); - case Type::FloatTyID: return castToFloat(V); - case Type::DoubleTyID: return castToDouble(V); - case Type::PointerTyID: - return castToPointer(V, reinterpret_cast(Ty)); - default: return 0; - } - } + virtual Constant *castToBool (const Constant *V) const = 0; + virtual Constant *castToSByte (const Constant *V) const = 0; + virtual Constant *castToUByte (const Constant *V) const = 0; + virtual Constant *castToShort (const Constant *V) const = 0; + virtual Constant *castToUShort(const Constant *V) const = 0; + virtual Constant *castToInt (const Constant *V) const = 0; + virtual Constant *castToUInt (const Constant *V) const = 0; + virtual Constant *castToLong (const Constant *V) const = 0; + virtual Constant *castToULong (const Constant *V) const = 0; + virtual Constant *castToFloat (const Constant *V) const = 0; + virtual Constant *castToDouble(const Constant *V) const = 0; + virtual Constant *castToPointer(const Constant *V, + const PointerType *Ty) const = 0; // ConstRules::get - Return an instance of ConstRules for the specified // constant operands. diff --git a/lib/VMCore/ConstantFolding.h b/lib/VMCore/ConstantFolding.h index 8475e44f908..dc5d0cfbfae 100644 --- a/lib/VMCore/ConstantFolding.h +++ b/lib/VMCore/ConstantFolding.h @@ -15,12 +15,12 @@ #ifndef CONSTANTHANDLING_H #define CONSTANTHANDLING_H -#include "llvm/Constants.h" -#include "llvm/Type.h" +#include namespace llvm { - -class PointerType; + class Constant; + class Type; + class PointerType; struct ConstRules { ConstRules() {} @@ -37,44 +37,24 @@ struct ConstRules { virtual Constant *shl(const Constant *V1, const Constant *V2) const = 0; virtual Constant *shr(const Constant *V1, const Constant *V2) const = 0; - virtual ConstantBool *lessthan(const Constant *V1, - const Constant *V2) const = 0; - virtual ConstantBool *equalto(const Constant *V1, - const Constant *V2) const = 0; + virtual Constant *lessthan(const Constant *V1, const Constant *V2) const = 0; + + virtual Constant *equalto(const Constant *V1, const Constant *V2) const = 0; // Casting operators. ick - virtual ConstantBool *castToBool (const Constant *V) const = 0; - virtual ConstantSInt *castToSByte (const Constant *V) const = 0; - virtual ConstantUInt *castToUByte (const Constant *V) const = 0; - virtual ConstantSInt *castToShort (const Constant *V) const = 0; - virtual ConstantUInt *castToUShort(const Constant *V) const = 0; - virtual ConstantSInt *castToInt (const Constant *V) const = 0; - virtual ConstantUInt *castToUInt (const Constant *V) const = 0; - virtual ConstantSInt *castToLong (const Constant *V) const = 0; - virtual ConstantUInt *castToULong (const Constant *V) const = 0; - virtual ConstantFP *castToFloat (const Constant *V) const = 0; - virtual ConstantFP *castToDouble(const Constant *V) const = 0; - virtual Constant *castToPointer(const Constant *V, - const PointerType *Ty) const = 0; - - inline Constant *castTo(const Constant *V, const Type *Ty) const { - switch (Ty->getPrimitiveID()) { - case Type::BoolTyID: return castToBool(V); - case Type::UByteTyID: return castToUByte(V); - case Type::SByteTyID: return castToSByte(V); - case Type::UShortTyID: return castToUShort(V); - case Type::ShortTyID: return castToShort(V); - case Type::UIntTyID: return castToUInt(V); - case Type::IntTyID: return castToInt(V); - case Type::ULongTyID: return castToULong(V); - case Type::LongTyID: return castToLong(V); - case Type::FloatTyID: return castToFloat(V); - case Type::DoubleTyID: return castToDouble(V); - case Type::PointerTyID: - return castToPointer(V, reinterpret_cast(Ty)); - default: return 0; - } - } + virtual Constant *castToBool (const Constant *V) const = 0; + virtual Constant *castToSByte (const Constant *V) const = 0; + virtual Constant *castToUByte (const Constant *V) const = 0; + virtual Constant *castToShort (const Constant *V) const = 0; + virtual Constant *castToUShort(const Constant *V) const = 0; + virtual Constant *castToInt (const Constant *V) const = 0; + virtual Constant *castToUInt (const Constant *V) const = 0; + virtual Constant *castToLong (const Constant *V) const = 0; + virtual Constant *castToULong (const Constant *V) const = 0; + virtual Constant *castToFloat (const Constant *V) const = 0; + virtual Constant *castToDouble(const Constant *V) const = 0; + virtual Constant *castToPointer(const Constant *V, + const PointerType *Ty) const = 0; // ConstRules::get - Return an instance of ConstRules for the specified // constant operands.