diff --git a/include/llvm/Support/ErrorOr.h b/include/llvm/Support/ErrorOr.h index 828d77b852e..ceec33d1852 100644 --- a/include/llvm/Support/ErrorOr.h +++ b/include/llvm/Support/ErrorOr.h @@ -162,6 +162,7 @@ public: /// T cannot be a rvalue reference. template class ErrorOr { + template friend class ErrorOr; static const bool isRef = is_reference::value; typedef ReferenceStorage::type> wrap; @@ -199,60 +200,43 @@ public: } ErrorOr(const ErrorOr &Other) : IsValid(false) { - // Construct an invalid ErrorOr if other is invalid. - if (!Other.IsValid) - return; - IsValid = true; - if (!Other.HasError) { - // Get the other value. - HasError = false; - new (get()) storage_type(*Other.get()); - } else { - // Get other's error. - Error = Other.Error; - HasError = true; - Error->aquire(); - } + copyConstruct(Other); + } + + template + ErrorOr(const ErrorOr &Other) : IsValid(false) { + copyConstruct(Other); } ErrorOr &operator =(const ErrorOr &Other) { - if (this == &Other) - return *this; - - this->~ErrorOr(); - new (this) ErrorOr(Other); + copyAssign(Other); + return *this; + } + template + ErrorOr &operator =(const ErrorOr &Other) { + copyAssign(Other); return *this; } #if LLVM_HAS_RVALUE_REFERENCES ErrorOr(ErrorOr &&Other) : IsValid(false) { - // Construct an invalid ErrorOr if other is invalid. - if (!Other.IsValid) - return; - IsValid = true; - if (!Other.HasError) { - // Get the other value. - HasError = false; - new (get()) storage_type(std::move(*Other.get())); - // Tell other not to do any destruction. - Other.IsValid = false; - } else { - // Get other's error. - Error = Other.Error; - HasError = true; - // Tell other not to do any destruction. - Other.IsValid = false; - } + moveConstruct(std::move(Other)); + } + + template + ErrorOr(ErrorOr &&Other) : IsValid(false) { + moveConstruct(std::move(Other)); } ErrorOr &operator =(ErrorOr &&Other) { - if (this == &Other) - return *this; - - this->~ErrorOr(); - new (this) ErrorOr(std::move(Other)); + moveAssign(std::move(Other)); + return *this; + } + template + ErrorOr &operator =(ErrorOr &&Other) { + moveAssign(std::move(Other)); return *this; } #endif @@ -300,6 +284,75 @@ public: } private: + template + void copyConstruct(const ErrorOr &Other) { + // Construct an invalid ErrorOr if other is invalid. + if (!Other.IsValid) + return; + IsValid = true; + if (!Other.HasError) { + // Get the other value. + HasError = false; + new (get()) storage_type(*Other.get()); + } else { + // Get other's error. + Error = Other.Error; + HasError = true; + Error->aquire(); + } + } + + template + static bool compareThisIfSameType(const T1 &a, const T1 &b) { + return &a == &b; + } + + template + static bool compareThisIfSameType(const T1 &a, const T2 &b) { + return false; + } + + template + void copyAssign(const ErrorOr &Other) { + if (compareThisIfSameType(*this, Other)) + return; + + this->~ErrorOr(); + new (this) ErrorOr(Other); + } + +#if LLVM_HAS_RVALUE_REFERENCES + template + void moveConstruct(ErrorOr &&Other) { + // Construct an invalid ErrorOr if other is invalid. + if (!Other.IsValid) + return; + IsValid = true; + if (!Other.HasError) { + // Get the other value. + HasError = false; + new (get()) storage_type(std::move(*Other.get())); + // Tell other not to do any destruction. + Other.IsValid = false; + } else { + // Get other's error. + Error = Other.Error; + HasError = true; + // Tell other not to do any destruction. + Other.IsValid = false; + } + } + + template + void moveAssign(ErrorOr &&Other) { + if (compareThisIfSameType(*this, Other)) + return; + + this->~ErrorOr(); + new (this) ErrorOr(std::move(Other)); + } +#endif + pointer toPointer(pointer Val) { return Val; } @@ -308,7 +361,6 @@ private: return &Val->get(); } -protected: storage_type *get() { assert(IsValid && "Can't do anything on a default constructed ErrorOr!"); assert(!HasError && "Cannot get value when an error exists!"); diff --git a/unittests/Support/ErrorOrTest.cpp b/unittests/Support/ErrorOrTest.cpp index a8608860b84..aa0ddd5e79c 100644 --- a/unittests/Support/ErrorOrTest.cpp +++ b/unittests/Support/ErrorOrTest.cpp @@ -53,6 +53,19 @@ TEST(ErrorOr, Types) { EXPECT_EQ(3, **t3()); #endif } + +struct B {}; +struct D : B {}; + +TEST(ErrorOr, Covariant) { + ErrorOr b(ErrorOr(0)); + b = ErrorOr(0); + +#if LLVM_HAS_CXX11_STDLIB + ErrorOr > b1(ErrorOr >(0)); + b1 = ErrorOr >(0); +#endif +} } // end anon namespace struct InvalidArgError {