From 02fb8d423d1d0eb5a05f99ec04fb085a21facf92 Mon Sep 17 00:00:00 2001 From: Evan Harvey Date: Wed, 10 May 2023 13:07:02 -0600 Subject: [PATCH] core/src: Move floating_point_wrapper to private header --- core/src/Kokkos_Half.hpp | 997 +--------------- .../impl/Kokkos_Half_FloatingPointWrapper.hpp | 1016 +++++++++++++++++ 2 files changed, 1017 insertions(+), 996 deletions(-) create mode 100644 core/src/impl/Kokkos_Half_FloatingPointWrapper.hpp diff --git a/core/src/Kokkos_Half.hpp b/core/src/Kokkos_Half.hpp index 179141220f..91b94b4cfa 100644 --- a/core/src/Kokkos_Half.hpp +++ b/core/src/Kokkos_Half.hpp @@ -21,1002 +21,7 @@ #define KOKKOS_IMPL_PUBLIC_INCLUDE_NOTDEFINED_HALF #endif -#include - -#include -#include // istream & ostream for extraction and insertion ops -#include - -#ifdef KOKKOS_IMPL_HALF_TYPE_DEFINED - -// KOKKOS_HALF_IS_FULL_TYPE_ON_ARCH: A macro to select which -// floating_pointer_wrapper operator paths should be used. For CUDA, let the -// compiler conditionally select when device ops are used For SYCL, we have a -// full half type on both host and device -#if defined(__CUDA_ARCH__) || defined(KOKKOS_ENABLE_SYCL) -#define KOKKOS_HALF_IS_FULL_TYPE_ON_ARCH -#endif - -/************************* BEGIN forward declarations *************************/ -namespace Kokkos { -namespace Experimental { -namespace Impl { -template -class floating_point_wrapper; -} - -// Declare half_t (binary16) -using half_t = Kokkos::Experimental::Impl::floating_point_wrapper< - Kokkos::Impl::half_impl_t ::type>; -KOKKOS_INLINE_FUNCTION -half_t cast_to_half(float val); -KOKKOS_INLINE_FUNCTION -half_t cast_to_half(bool val); -KOKKOS_INLINE_FUNCTION -half_t cast_to_half(double val); -KOKKOS_INLINE_FUNCTION -half_t cast_to_half(short val); -KOKKOS_INLINE_FUNCTION -half_t cast_to_half(int val); -KOKKOS_INLINE_FUNCTION -half_t cast_to_half(long val); -KOKKOS_INLINE_FUNCTION -half_t cast_to_half(long long val); -KOKKOS_INLINE_FUNCTION -half_t cast_to_half(unsigned short val); -KOKKOS_INLINE_FUNCTION -half_t cast_to_half(unsigned int val); -KOKKOS_INLINE_FUNCTION -half_t cast_to_half(unsigned long val); -KOKKOS_INLINE_FUNCTION -half_t cast_to_half(unsigned long long val); -KOKKOS_INLINE_FUNCTION -half_t cast_to_half(half_t); - -template -KOKKOS_INLINE_FUNCTION std::enable_if_t::value, T> - cast_from_half(half_t); -template -KOKKOS_INLINE_FUNCTION std::enable_if_t::value, T> - cast_from_half(half_t); -template -KOKKOS_INLINE_FUNCTION std::enable_if_t::value, T> - cast_from_half(half_t); -template -KOKKOS_INLINE_FUNCTION std::enable_if_t::value, T> - cast_from_half(half_t); -template -KOKKOS_INLINE_FUNCTION std::enable_if_t::value, T> - cast_from_half(half_t); -template -KOKKOS_INLINE_FUNCTION std::enable_if_t::value, T> - cast_from_half(half_t); -template -KOKKOS_INLINE_FUNCTION std::enable_if_t::value, T> - cast_from_half(half_t); -template -KOKKOS_INLINE_FUNCTION - std::enable_if_t::value, T> - cast_from_half(half_t); -template -KOKKOS_INLINE_FUNCTION std::enable_if_t::value, T> - cast_from_half(half_t); -template -KOKKOS_INLINE_FUNCTION - std::enable_if_t::value, T> - cast_from_half(half_t); -template -KOKKOS_INLINE_FUNCTION - std::enable_if_t::value, T> - cast_from_half(half_t); - -// declare bhalf_t -#ifdef KOKKOS_IMPL_BHALF_TYPE_DEFINED -using bhalf_t = Kokkos::Experimental::Impl::floating_point_wrapper< - Kokkos::Impl ::bhalf_impl_t ::type>; - -KOKKOS_INLINE_FUNCTION -bhalf_t cast_to_bhalf(float val); -KOKKOS_INLINE_FUNCTION -bhalf_t cast_to_bhalf(bool val); -KOKKOS_INLINE_FUNCTION -bhalf_t cast_to_bhalf(double val); -KOKKOS_INLINE_FUNCTION -bhalf_t cast_to_bhalf(short val); -KOKKOS_INLINE_FUNCTION -bhalf_t cast_to_bhalf(int val); -KOKKOS_INLINE_FUNCTION -bhalf_t cast_to_bhalf(long val); -KOKKOS_INLINE_FUNCTION -bhalf_t cast_to_bhalf(long long val); -KOKKOS_INLINE_FUNCTION -bhalf_t cast_to_bhalf(unsigned short val); -KOKKOS_INLINE_FUNCTION -bhalf_t cast_to_bhalf(unsigned int val); -KOKKOS_INLINE_FUNCTION -bhalf_t cast_to_bhalf(unsigned long val); -KOKKOS_INLINE_FUNCTION -bhalf_t cast_to_bhalf(unsigned long long val); -KOKKOS_INLINE_FUNCTION -bhalf_t cast_to_bhalf(bhalf_t val); - -template -KOKKOS_INLINE_FUNCTION std::enable_if_t::value, T> - cast_from_bhalf(bhalf_t); -template -KOKKOS_INLINE_FUNCTION std::enable_if_t::value, T> - cast_from_bhalf(bhalf_t); -template -KOKKOS_INLINE_FUNCTION std::enable_if_t::value, T> - cast_from_bhalf(bhalf_t); -template -KOKKOS_INLINE_FUNCTION std::enable_if_t::value, T> - cast_from_bhalf(bhalf_t); -template -KOKKOS_INLINE_FUNCTION std::enable_if_t::value, T> - cast_from_bhalf(bhalf_t); -template -KOKKOS_INLINE_FUNCTION std::enable_if_t::value, T> - cast_from_bhalf(bhalf_t); -template -KOKKOS_INLINE_FUNCTION std::enable_if_t::value, T> - cast_from_bhalf(bhalf_t); -template -KOKKOS_INLINE_FUNCTION - std::enable_if_t::value, T> - cast_from_bhalf(bhalf_t); -template -KOKKOS_INLINE_FUNCTION std::enable_if_t::value, T> - cast_from_bhalf(bhalf_t); -template -KOKKOS_INLINE_FUNCTION - std::enable_if_t::value, T> - cast_from_bhalf(bhalf_t); -template -KOKKOS_INLINE_FUNCTION - std::enable_if_t::value, T> - cast_from_bhalf(bhalf_t); -#endif // KOKKOS_IMPL_BHALF_TYPE_DEFINED - -template -static KOKKOS_INLINE_FUNCTION Kokkos::Experimental::half_t cast_to_wrapper( - T x, const volatile Kokkos::Impl::half_impl_t::type&); - -#ifdef KOKKOS_IMPL_BHALF_TYPE_DEFINED -template -static KOKKOS_INLINE_FUNCTION Kokkos::Experimental::bhalf_t cast_to_wrapper( - T x, const volatile Kokkos::Impl::bhalf_impl_t::type&); -#endif // KOKKOS_IMPL_BHALF_TYPE_DEFINED - -template -static KOKKOS_INLINE_FUNCTION T -cast_from_wrapper(const Kokkos::Experimental::half_t& x); - -#ifdef KOKKOS_IMPL_BHALF_TYPE_DEFINED -template -static KOKKOS_INLINE_FUNCTION T -cast_from_wrapper(const Kokkos::Experimental::bhalf_t& x); -#endif // KOKKOS_IMPL_BHALF_TYPE_DEFINED -/************************** END forward declarations **************************/ - -namespace Impl { -template -class alignas(FloatType) floating_point_wrapper { - public: - using impl_type = FloatType; - - private: - impl_type val; - using fixed_width_integer_type = std::conditional_t< - sizeof(impl_type) == 2, uint16_t, - std::conditional_t< - sizeof(impl_type) == 4, uint32_t, - std::conditional_t>>; - static_assert(!std::is_void::value, - "Invalid impl_type"); - - public: - // In-class initialization and defaulted default constructors not used - // since Cuda supports half precision initialization via the below constructor - KOKKOS_FUNCTION - floating_point_wrapper() : val(0.0F) {} - -// Copy constructors -// Getting "C2580: multiple versions of a defaulted special -// member function are not allowed" with VS 16.11.3 and CUDA 11.4.2 -#if defined(_WIN32) && defined(KOKKOS_ENABLE_CUDA) - KOKKOS_FUNCTION - floating_point_wrapper(const floating_point_wrapper& rhs) : val(rhs.val) {} - - KOKKOS_FUNCTION - floating_point_wrapper& operator=(const floating_point_wrapper& rhs) { - val = rhs.val; - return *this; - } -#else - KOKKOS_DEFAULTED_FUNCTION - floating_point_wrapper(const floating_point_wrapper&) noexcept = default; - - KOKKOS_DEFAULTED_FUNCTION - floating_point_wrapper& operator=(const floating_point_wrapper&) noexcept = - default; -#endif - - KOKKOS_INLINE_FUNCTION - floating_point_wrapper(const volatile floating_point_wrapper& rhs) { -#if defined(KOKKOS_HALF_IS_FULL_TYPE_ON_ARCH) && !defined(KOKKOS_ENABLE_SYCL) - val = rhs.val; -#else - const volatile fixed_width_integer_type* rv_ptr = - reinterpret_cast(&rhs.val); - const fixed_width_integer_type rv_val = *rv_ptr; - val = reinterpret_cast(rv_val); -#endif // KOKKOS_HALF_IS_FULL_TYPE_ON_ARCH - } - - // Don't support implicit conversion back to impl_type. - // impl_type is a storage only type on host. - KOKKOS_FUNCTION - explicit operator impl_type() const { return val; } - KOKKOS_FUNCTION - explicit operator float() const { return cast_from_wrapper(*this); } - KOKKOS_FUNCTION - explicit operator bool() const { return cast_from_wrapper(*this); } - KOKKOS_FUNCTION - explicit operator double() const { return cast_from_wrapper(*this); } - KOKKOS_FUNCTION - explicit operator short() const { return cast_from_wrapper(*this); } - KOKKOS_FUNCTION - explicit operator int() const { return cast_from_wrapper(*this); } - KOKKOS_FUNCTION - explicit operator long() const { return cast_from_wrapper(*this); } - KOKKOS_FUNCTION - explicit operator long long() const { - return cast_from_wrapper(*this); - } - KOKKOS_FUNCTION - explicit operator unsigned short() const { - return cast_from_wrapper(*this); - } - KOKKOS_FUNCTION - explicit operator unsigned int() const { - return cast_from_wrapper(*this); - } - KOKKOS_FUNCTION - explicit operator unsigned long() const { - return cast_from_wrapper(*this); - } - KOKKOS_FUNCTION - explicit operator unsigned long long() const { - return cast_from_wrapper(*this); - } - - /** - * Conversion constructors. - * - * Support implicit conversions from impl_type, float, double -> - * floating_point_wrapper. Mixed precision expressions require upcasting which - * is done in the - * "// Binary Arithmetic" operator overloads below. - * - * Support implicit conversions from integral types -> floating_point_wrapper. - * Expressions involving floating_point_wrapper with integral types require - * downcasting the integral types to floating_point_wrapper. Existing operator - * overloads can handle this with the addition of the below implicit - * conversion constructors. - */ - KOKKOS_FUNCTION - constexpr floating_point_wrapper(impl_type rhs) : val(rhs) {} - KOKKOS_FUNCTION - floating_point_wrapper(float rhs) : val(cast_to_wrapper(rhs, val).val) {} - KOKKOS_FUNCTION - floating_point_wrapper(double rhs) : val(cast_to_wrapper(rhs, val).val) {} - KOKKOS_FUNCTION - explicit floating_point_wrapper(bool rhs) - : val(cast_to_wrapper(rhs, val).val) {} - KOKKOS_FUNCTION - floating_point_wrapper(short rhs) : val(cast_to_wrapper(rhs, val).val) {} - KOKKOS_FUNCTION - floating_point_wrapper(int rhs) : val(cast_to_wrapper(rhs, val).val) {} - KOKKOS_FUNCTION - floating_point_wrapper(long rhs) : val(cast_to_wrapper(rhs, val).val) {} - KOKKOS_FUNCTION - floating_point_wrapper(long long rhs) : val(cast_to_wrapper(rhs, val).val) {} - KOKKOS_FUNCTION - floating_point_wrapper(unsigned short rhs) - : val(cast_to_wrapper(rhs, val).val) {} - KOKKOS_FUNCTION - floating_point_wrapper(unsigned int rhs) - : val(cast_to_wrapper(rhs, val).val) {} - KOKKOS_FUNCTION - floating_point_wrapper(unsigned long rhs) - : val(cast_to_wrapper(rhs, val).val) {} - KOKKOS_FUNCTION - floating_point_wrapper(unsigned long long rhs) - : val(cast_to_wrapper(rhs, val).val) {} - - // Unary operators - KOKKOS_FUNCTION - floating_point_wrapper operator+() const { - floating_point_wrapper tmp = *this; -#ifdef KOKKOS_HALF_IS_FULL_TYPE_ON_ARCH - tmp.val = +tmp.val; -#else - tmp.val = cast_to_wrapper(+cast_from_wrapper(tmp), val).val; -#endif - return tmp; - } - - KOKKOS_FUNCTION - floating_point_wrapper operator-() const { - floating_point_wrapper tmp = *this; -#ifdef KOKKOS_HALF_IS_FULL_TYPE_ON_ARCH - tmp.val = -tmp.val; -#else - tmp.val = cast_to_wrapper(-cast_from_wrapper(tmp), val).val; -#endif - return tmp; - } - - // Prefix operators - KOKKOS_FUNCTION - floating_point_wrapper& operator++() { -#ifdef KOKKOS_HALF_IS_FULL_TYPE_ON_ARCH - val = val + impl_type(1.0F); // cuda has no operator++ for __nv_bfloat -#else - float tmp = cast_from_wrapper(*this); - ++tmp; - val = cast_to_wrapper(tmp, val).val; -#endif - return *this; - } - - KOKKOS_FUNCTION - floating_point_wrapper& operator--() { -#ifdef KOKKOS_HALF_IS_FULL_TYPE_ON_ARCH - val = val - impl_type(1.0F); // cuda has no operator-- for __nv_bfloat -#else - float tmp = cast_from_wrapper(*this); - --tmp; - val = cast_to_wrapper(tmp, val).val; -#endif - return *this; - } - - // Postfix operators - KOKKOS_FUNCTION - floating_point_wrapper operator++(int) { - floating_point_wrapper tmp = *this; - operator++(); - return tmp; - } - - KOKKOS_FUNCTION - floating_point_wrapper operator--(int) { - floating_point_wrapper tmp = *this; - operator--(); - return tmp; - } - - // Binary operators - KOKKOS_FUNCTION - floating_point_wrapper& operator=(impl_type rhs) { - val = rhs; - return *this; - } - - template - KOKKOS_FUNCTION floating_point_wrapper& operator=(T rhs) { - val = cast_to_wrapper(rhs, val).val; - return *this; - } - - template - KOKKOS_FUNCTION void operator=(T rhs) volatile { - impl_type new_val = cast_to_wrapper(rhs, val).val; - volatile fixed_width_integer_type* val_ptr = - reinterpret_cast( - const_cast(&val)); - *val_ptr = reinterpret_cast(new_val); - } - - // Compound operators - KOKKOS_FUNCTION - floating_point_wrapper& operator+=(floating_point_wrapper rhs) { -#ifdef KOKKOS_HALF_IS_FULL_TYPE_ON_ARCH - val = val + rhs.val; // cuda has no operator+= for __nv_bfloat -#else - val = cast_to_wrapper( - cast_from_wrapper(*this) + cast_from_wrapper(rhs), - val) - .val; -#endif - return *this; - } - - KOKKOS_FUNCTION - void operator+=(const volatile floating_point_wrapper& rhs) volatile { - floating_point_wrapper tmp_rhs = rhs; - floating_point_wrapper tmp_lhs = *this; - - tmp_lhs += tmp_rhs; - *this = tmp_lhs; - } - - // Compound operators: upcast overloads for += - template - KOKKOS_FUNCTION friend std::enable_if_t< - std::is_same::value || std::is_same::value, T> - operator+=(T& lhs, floating_point_wrapper rhs) { - lhs += static_cast(rhs); - return lhs; - } - - KOKKOS_FUNCTION - floating_point_wrapper& operator+=(float rhs) { - float result = static_cast(val) + rhs; - val = static_cast(result); - return *this; - } - - KOKKOS_FUNCTION - floating_point_wrapper& operator+=(double rhs) { - double result = static_cast(val) + rhs; - val = static_cast(result); - return *this; - } - - KOKKOS_FUNCTION - floating_point_wrapper& operator-=(floating_point_wrapper rhs) { -#ifdef KOKKOS_HALF_IS_FULL_TYPE_ON_ARCH - val = val - rhs.val; // cuda has no operator-= for __nv_bfloat -#else - val = cast_to_wrapper( - cast_from_wrapper(*this) - cast_from_wrapper(rhs), - val) - .val; -#endif - return *this; - } - - KOKKOS_FUNCTION - void operator-=(const volatile floating_point_wrapper& rhs) volatile { - floating_point_wrapper tmp_rhs = rhs; - floating_point_wrapper tmp_lhs = *this; - - tmp_lhs -= tmp_rhs; - *this = tmp_lhs; - } - - // Compund operators: upcast overloads for -= - template - KOKKOS_FUNCTION friend std::enable_if_t< - std::is_same::value || std::is_same::value, T> - operator-=(T& lhs, floating_point_wrapper rhs) { - lhs -= static_cast(rhs); - return lhs; - } - - KOKKOS_FUNCTION - floating_point_wrapper& operator-=(float rhs) { - float result = static_cast(val) - rhs; - val = static_cast(result); - return *this; - } - - KOKKOS_FUNCTION - floating_point_wrapper& operator-=(double rhs) { - double result = static_cast(val) - rhs; - val = static_cast(result); - return *this; - } - - KOKKOS_FUNCTION - floating_point_wrapper& operator*=(floating_point_wrapper rhs) { -#ifdef KOKKOS_HALF_IS_FULL_TYPE_ON_ARCH - val = val * rhs.val; // cuda has no operator*= for __nv_bfloat -#else - val = cast_to_wrapper( - cast_from_wrapper(*this) * cast_from_wrapper(rhs), - val) - .val; -#endif - return *this; - } - - KOKKOS_FUNCTION - void operator*=(const volatile floating_point_wrapper& rhs) volatile { - floating_point_wrapper tmp_rhs = rhs; - floating_point_wrapper tmp_lhs = *this; - - tmp_lhs *= tmp_rhs; - *this = tmp_lhs; - } - - // Compund operators: upcast overloads for *= - template - KOKKOS_FUNCTION friend std::enable_if_t< - std::is_same::value || std::is_same::value, T> - operator*=(T& lhs, floating_point_wrapper rhs) { - lhs *= static_cast(rhs); - return lhs; - } - - KOKKOS_FUNCTION - floating_point_wrapper& operator*=(float rhs) { - float result = static_cast(val) * rhs; - val = static_cast(result); - return *this; - } - - KOKKOS_FUNCTION - floating_point_wrapper& operator*=(double rhs) { - double result = static_cast(val) * rhs; - val = static_cast(result); - return *this; - } - - KOKKOS_FUNCTION - floating_point_wrapper& operator/=(floating_point_wrapper rhs) { -#ifdef KOKKOS_HALF_IS_FULL_TYPE_ON_ARCH - val = val / rhs.val; // cuda has no operator/= for __nv_bfloat -#else - val = cast_to_wrapper( - cast_from_wrapper(*this) / cast_from_wrapper(rhs), - val) - .val; -#endif - return *this; - } - - KOKKOS_FUNCTION - void operator/=(const volatile floating_point_wrapper& rhs) volatile { - floating_point_wrapper tmp_rhs = rhs; - floating_point_wrapper tmp_lhs = *this; - - tmp_lhs /= tmp_rhs; - *this = tmp_lhs; - } - - // Compund operators: upcast overloads for /= - template - KOKKOS_FUNCTION friend std::enable_if_t< - std::is_same::value || std::is_same::value, T> - operator/=(T& lhs, floating_point_wrapper rhs) { - lhs /= static_cast(rhs); - return lhs; - } - - KOKKOS_FUNCTION - floating_point_wrapper& operator/=(float rhs) { - float result = static_cast(val) / rhs; - val = static_cast(result); - return *this; - } - - KOKKOS_FUNCTION - floating_point_wrapper& operator/=(double rhs) { - double result = static_cast(val) / rhs; - val = static_cast(result); - return *this; - } - - // Binary Arithmetic - KOKKOS_FUNCTION - friend floating_point_wrapper operator+(floating_point_wrapper lhs, - floating_point_wrapper rhs) { -#ifdef KOKKOS_HALF_IS_FULL_TYPE_ON_ARCH - lhs += rhs; -#else - lhs.val = cast_to_wrapper( - cast_from_wrapper(lhs) + cast_from_wrapper(rhs), - lhs.val) - .val; -#endif - return lhs; - } - - // Binary Arithmetic upcast operators for + - template - KOKKOS_FUNCTION friend std::enable_if_t< - std::is_same::value || std::is_same::value, T> - operator+(floating_point_wrapper lhs, T rhs) { - return T(lhs) + rhs; - } - - template - KOKKOS_FUNCTION friend std::enable_if_t< - std::is_same::value || std::is_same::value, T> - operator+(T lhs, floating_point_wrapper rhs) { - return lhs + T(rhs); - } - - KOKKOS_FUNCTION - friend floating_point_wrapper operator-(floating_point_wrapper lhs, - floating_point_wrapper rhs) { -#ifdef KOKKOS_HALF_IS_FULL_TYPE_ON_ARCH - lhs -= rhs; -#else - lhs.val = cast_to_wrapper( - cast_from_wrapper(lhs) - cast_from_wrapper(rhs), - lhs.val) - .val; -#endif - return lhs; - } - - // Binary Arithmetic upcast operators for - - template - KOKKOS_FUNCTION friend std::enable_if_t< - std::is_same::value || std::is_same::value, T> - operator-(floating_point_wrapper lhs, T rhs) { - return T(lhs) - rhs; - } - - template - KOKKOS_FUNCTION friend std::enable_if_t< - std::is_same::value || std::is_same::value, T> - operator-(T lhs, floating_point_wrapper rhs) { - return lhs - T(rhs); - } - - KOKKOS_FUNCTION - friend floating_point_wrapper operator*(floating_point_wrapper lhs, - floating_point_wrapper rhs) { -#ifdef KOKKOS_HALF_IS_FULL_TYPE_ON_ARCH - lhs *= rhs; -#else - lhs.val = cast_to_wrapper( - cast_from_wrapper(lhs) * cast_from_wrapper(rhs), - lhs.val) - .val; -#endif - return lhs; - } - - // Binary Arithmetic upcast operators for * - template - KOKKOS_FUNCTION friend std::enable_if_t< - std::is_same::value || std::is_same::value, T> - operator*(floating_point_wrapper lhs, T rhs) { - return T(lhs) * rhs; - } - - template - KOKKOS_FUNCTION friend std::enable_if_t< - std::is_same::value || std::is_same::value, T> - operator*(T lhs, floating_point_wrapper rhs) { - return lhs * T(rhs); - } - - KOKKOS_FUNCTION - friend floating_point_wrapper operator/(floating_point_wrapper lhs, - floating_point_wrapper rhs) { -#ifdef KOKKOS_HALF_IS_FULL_TYPE_ON_ARCH - lhs /= rhs; -#else - lhs.val = cast_to_wrapper( - cast_from_wrapper(lhs) / cast_from_wrapper(rhs), - lhs.val) - .val; -#endif - return lhs; - } - - // Binary Arithmetic upcast operators for / - template - KOKKOS_FUNCTION friend std::enable_if_t< - std::is_same::value || std::is_same::value, T> - operator/(floating_point_wrapper lhs, T rhs) { - return T(lhs) / rhs; - } - - template - KOKKOS_FUNCTION friend std::enable_if_t< - std::is_same::value || std::is_same::value, T> - operator/(T lhs, floating_point_wrapper rhs) { - return lhs / T(rhs); - } - - // Logical operators - KOKKOS_FUNCTION - bool operator!() const { -#ifdef KOKKOS_HALF_IS_FULL_TYPE_ON_ARCH - return static_cast(!val); -#else - return !cast_from_wrapper(*this); -#endif - } - - // NOTE: Loses short-circuit evaluation - KOKKOS_FUNCTION - bool operator&&(floating_point_wrapper rhs) const { -#ifdef KOKKOS_HALF_IS_FULL_TYPE_ON_ARCH - return static_cast(val && rhs.val); -#else - return cast_from_wrapper(*this) && cast_from_wrapper(rhs); -#endif - } - - // NOTE: Loses short-circuit evaluation - KOKKOS_FUNCTION - bool operator||(floating_point_wrapper rhs) const { -#ifdef KOKKOS_HALF_IS_FULL_TYPE_ON_ARCH - return static_cast(val || rhs.val); -#else - return cast_from_wrapper(*this) || cast_from_wrapper(rhs); -#endif - } - - // Comparison operators - KOKKOS_FUNCTION - bool operator==(floating_point_wrapper rhs) const { -#ifdef KOKKOS_HALF_IS_FULL_TYPE_ON_ARCH - return static_cast(val == rhs.val); -#else - return cast_from_wrapper(*this) == cast_from_wrapper(rhs); -#endif - } - - KOKKOS_FUNCTION - bool operator!=(floating_point_wrapper rhs) const { -#ifdef KOKKOS_HALF_IS_FULL_TYPE_ON_ARCH - return static_cast(val != rhs.val); -#else - return cast_from_wrapper(*this) != cast_from_wrapper(rhs); -#endif - } - - KOKKOS_FUNCTION - bool operator<(floating_point_wrapper rhs) const { -#ifdef KOKKOS_HALF_IS_FULL_TYPE_ON_ARCH - return static_cast(val < rhs.val); -#else - return cast_from_wrapper(*this) < cast_from_wrapper(rhs); -#endif - } - - KOKKOS_FUNCTION - bool operator>(floating_point_wrapper rhs) const { -#ifdef KOKKOS_HALF_IS_FULL_TYPE_ON_ARCH - return static_cast(val > rhs.val); -#else - return cast_from_wrapper(*this) > cast_from_wrapper(rhs); -#endif - } - - KOKKOS_FUNCTION - bool operator<=(floating_point_wrapper rhs) const { -#ifdef KOKKOS_HALF_IS_FULL_TYPE_ON_ARCH - return static_cast(val <= rhs.val); -#else - return cast_from_wrapper(*this) <= cast_from_wrapper(rhs); -#endif - } - - KOKKOS_FUNCTION - bool operator>=(floating_point_wrapper rhs) const { -#ifdef KOKKOS_HALF_IS_FULL_TYPE_ON_ARCH - return static_cast(val >= rhs.val); -#else - return cast_from_wrapper(*this) >= cast_from_wrapper(rhs); -#endif - } - - KOKKOS_FUNCTION - friend bool operator==(const volatile floating_point_wrapper& lhs, - const volatile floating_point_wrapper& rhs) { - floating_point_wrapper tmp_lhs = lhs, tmp_rhs = rhs; - return tmp_lhs == tmp_rhs; - } - - KOKKOS_FUNCTION - friend bool operator!=(const volatile floating_point_wrapper& lhs, - const volatile floating_point_wrapper& rhs) { - floating_point_wrapper tmp_lhs = lhs, tmp_rhs = rhs; - return tmp_lhs != tmp_rhs; - } - - KOKKOS_FUNCTION - friend bool operator<(const volatile floating_point_wrapper& lhs, - const volatile floating_point_wrapper& rhs) { - floating_point_wrapper tmp_lhs = lhs, tmp_rhs = rhs; - return tmp_lhs < tmp_rhs; - } - - KOKKOS_FUNCTION - friend bool operator>(const volatile floating_point_wrapper& lhs, - const volatile floating_point_wrapper& rhs) { - floating_point_wrapper tmp_lhs = lhs, tmp_rhs = rhs; - return tmp_lhs > tmp_rhs; - } - - KOKKOS_FUNCTION - friend bool operator<=(const volatile floating_point_wrapper& lhs, - const volatile floating_point_wrapper& rhs) { - floating_point_wrapper tmp_lhs = lhs, tmp_rhs = rhs; - return tmp_lhs <= tmp_rhs; - } - - KOKKOS_FUNCTION - friend bool operator>=(const volatile floating_point_wrapper& lhs, - const volatile floating_point_wrapper& rhs) { - floating_point_wrapper tmp_lhs = lhs, tmp_rhs = rhs; - return tmp_lhs >= tmp_rhs; - } - - // Insertion and extraction operators - friend std::ostream& operator<<(std::ostream& os, - const floating_point_wrapper& x) { - const std::string out = std::to_string(static_cast(x)); - os << out; - return os; - } - - friend std::istream& operator>>(std::istream& is, floating_point_wrapper& x) { - std::string in; - is >> in; - x = std::stod(in); - return is; - } -}; -} // namespace Impl - -// Declare wrapper overloads now that floating_point_wrapper is declared -template -static KOKKOS_INLINE_FUNCTION Kokkos::Experimental::half_t cast_to_wrapper( - T x, const volatile Kokkos::Impl::half_impl_t::type&) { - return Kokkos::Experimental::cast_to_half(x); -} - -#ifdef KOKKOS_IMPL_BHALF_TYPE_DEFINED -template -static KOKKOS_INLINE_FUNCTION Kokkos::Experimental::bhalf_t cast_to_wrapper( - T x, const volatile Kokkos::Impl::bhalf_impl_t::type&) { - return Kokkos::Experimental::cast_to_bhalf(x); -} -#endif // KOKKOS_IMPL_BHALF_TYPE_DEFINED - -template -static KOKKOS_INLINE_FUNCTION T -cast_from_wrapper(const Kokkos::Experimental::half_t& x) { - return Kokkos::Experimental::cast_from_half(x); -} - -#ifdef KOKKOS_IMPL_BHALF_TYPE_DEFINED -template -static KOKKOS_INLINE_FUNCTION T -cast_from_wrapper(const Kokkos::Experimental::bhalf_t& x) { - return Kokkos::Experimental::cast_from_bhalf(x); -} -#endif // KOKKOS_IMPL_BHALF_TYPE_DEFINED - -} // namespace Experimental -} // namespace Kokkos - -#endif // KOKKOS_IMPL_HALF_TYPE_DEFINED - -// If none of the above actually did anything and defined a half precision type -// define a fallback implementation here using float -#ifndef KOKKOS_IMPL_HALF_TYPE_DEFINED -#define KOKKOS_IMPL_HALF_TYPE_DEFINED -#define KOKKOS_HALF_T_IS_FLOAT true -namespace Kokkos { -namespace Impl { -struct half_impl_t { - using type = float; -}; -} // namespace Impl -namespace Experimental { - -using half_t = Kokkos::Impl::half_impl_t::type; - -// cast_to_half -KOKKOS_INLINE_FUNCTION -half_t cast_to_half(float val) { return half_t(val); } -KOKKOS_INLINE_FUNCTION -half_t cast_to_half(bool val) { return half_t(val); } -KOKKOS_INLINE_FUNCTION -half_t cast_to_half(double val) { return half_t(val); } -KOKKOS_INLINE_FUNCTION -half_t cast_to_half(short val) { return half_t(val); } -KOKKOS_INLINE_FUNCTION -half_t cast_to_half(unsigned short val) { return half_t(val); } -KOKKOS_INLINE_FUNCTION -half_t cast_to_half(int val) { return half_t(val); } -KOKKOS_INLINE_FUNCTION -half_t cast_to_half(unsigned int val) { return half_t(val); } -KOKKOS_INLINE_FUNCTION -half_t cast_to_half(long val) { return half_t(val); } -KOKKOS_INLINE_FUNCTION -half_t cast_to_half(unsigned long val) { return half_t(val); } -KOKKOS_INLINE_FUNCTION -half_t cast_to_half(long long val) { return half_t(val); } -KOKKOS_INLINE_FUNCTION -half_t cast_to_half(unsigned long long val) { return half_t(val); } - -// cast_from_half -// Using an explicit list here too, since the other ones are explicit and for -// example don't include char -template -KOKKOS_INLINE_FUNCTION std::enable_if_t< - std::is_same::value || std::is_same::value || - std::is_same::value || std::is_same::value || - std::is_same::value || std::is_same::value || - std::is_same::value || std::is_same::value || - std::is_same::value || - std::is_same::value || - std::is_same::value, - T> -cast_from_half(half_t val) { - return T(val); -} - -} // namespace Experimental -} // namespace Kokkos - -#else -#define KOKKOS_HALF_T_IS_FLOAT false -#endif // KOKKOS_IMPL_HALF_TYPE_DEFINED - -#ifndef KOKKOS_IMPL_BHALF_TYPE_DEFINED -#define KOKKOS_IMPL_BHALF_TYPE_DEFINED -#define KOKKOS_BHALF_T_IS_FLOAT true -namespace Kokkos { -namespace Impl { -struct bhalf_impl_t { - using type = float; -}; -} // namespace Impl - -namespace Experimental { - -using bhalf_t = Kokkos::Impl::bhalf_impl_t::type; - -// cast_to_bhalf -KOKKOS_INLINE_FUNCTION -bhalf_t cast_to_bhalf(float val) { return bhalf_t(val); } -KOKKOS_INLINE_FUNCTION -bhalf_t cast_to_bhalf(bool val) { return bhalf_t(val); } -KOKKOS_INLINE_FUNCTION -bhalf_t cast_to_bhalf(double val) { return bhalf_t(val); } -KOKKOS_INLINE_FUNCTION -bhalf_t cast_to_bhalf(short val) { return bhalf_t(val); } -KOKKOS_INLINE_FUNCTION -bhalf_t cast_to_bhalf(unsigned short val) { return bhalf_t(val); } -KOKKOS_INLINE_FUNCTION -bhalf_t cast_to_bhalf(int val) { return bhalf_t(val); } -KOKKOS_INLINE_FUNCTION -bhalf_t cast_to_bhalf(unsigned int val) { return bhalf_t(val); } -KOKKOS_INLINE_FUNCTION -bhalf_t cast_to_bhalf(long val) { return bhalf_t(val); } -KOKKOS_INLINE_FUNCTION -bhalf_t cast_to_bhalf(unsigned long val) { return bhalf_t(val); } -KOKKOS_INLINE_FUNCTION -bhalf_t cast_to_bhalf(long long val) { return bhalf_t(val); } -KOKKOS_INLINE_FUNCTION -bhalf_t cast_to_bhalf(unsigned long long val) { return bhalf_t(val); } - -// cast_from_bhalf -template -KOKKOS_INLINE_FUNCTION std::enable_if_t< - std::is_same::value || std::is_same::value || - std::is_same::value || std::is_same::value || - std::is_same::value || std::is_same::value || - std::is_same::value || std::is_same::value || - std::is_same::value || - std::is_same::value || - std::is_same::value, - T> -cast_from_bhalf(bhalf_t val) { - return T(val); -} -} // namespace Experimental -} // namespace Kokkos -#else -#define KOKKOS_BHALF_T_IS_FLOAT false -#endif // KOKKOS_IMPL_BHALF_TYPE_DEFINED - +#include #include #ifdef KOKKOS_IMPL_PUBLIC_INCLUDE_NOTDEFINED_HALF diff --git a/core/src/impl/Kokkos_Half_FloatingPointWrapper.hpp b/core/src/impl/Kokkos_Half_FloatingPointWrapper.hpp new file mode 100644 index 0000000000..7bf315de17 --- /dev/null +++ b/core/src/impl/Kokkos_Half_FloatingPointWrapper.hpp @@ -0,0 +1,1016 @@ +//@HEADER +// ************************************************************************ +// +// Kokkos v. 4.0 +// Copyright (2022) National Technology & Engineering +// Solutions of Sandia, LLC (NTESS). +// +// Under the terms of Contract DE-NA0003525 with NTESS, +// the U.S. Government retains certain rights in this software. +// +// Part of Kokkos, under the Apache License v2.0 with LLVM Exceptions. +// See https://kokkos.org/LICENSE for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//@HEADER + +#ifndef KOKKOS_HALF_FLOATING_POINT_WRAPPER_HPP_ +#define KOKKOS_HALF_FLOATING_POINT_WRAPPER_HPP_ + +#include + +#include +#include // istream & ostream for extraction and insertion ops +#include + +#ifdef KOKKOS_IMPL_HALF_TYPE_DEFINED + +// KOKKOS_HALF_IS_FULL_TYPE_ON_ARCH: A macro to select which +// floating_pointer_wrapper operator paths should be used. For CUDA, let the +// compiler conditionally select when device ops are used For SYCL, we have a +// full half type on both host and device +#if defined(__CUDA_ARCH__) || defined(KOKKOS_ENABLE_SYCL) +#define KOKKOS_HALF_IS_FULL_TYPE_ON_ARCH +#endif + +/************************* BEGIN forward declarations *************************/ +namespace Kokkos { +namespace Experimental { +namespace Impl { +template +class floating_point_wrapper; +} + +// Declare half_t (binary16) +using half_t = Kokkos::Experimental::Impl::floating_point_wrapper< + Kokkos::Impl::half_impl_t ::type>; +KOKKOS_INLINE_FUNCTION +half_t cast_to_half(float val); +KOKKOS_INLINE_FUNCTION +half_t cast_to_half(bool val); +KOKKOS_INLINE_FUNCTION +half_t cast_to_half(double val); +KOKKOS_INLINE_FUNCTION +half_t cast_to_half(short val); +KOKKOS_INLINE_FUNCTION +half_t cast_to_half(int val); +KOKKOS_INLINE_FUNCTION +half_t cast_to_half(long val); +KOKKOS_INLINE_FUNCTION +half_t cast_to_half(long long val); +KOKKOS_INLINE_FUNCTION +half_t cast_to_half(unsigned short val); +KOKKOS_INLINE_FUNCTION +half_t cast_to_half(unsigned int val); +KOKKOS_INLINE_FUNCTION +half_t cast_to_half(unsigned long val); +KOKKOS_INLINE_FUNCTION +half_t cast_to_half(unsigned long long val); +KOKKOS_INLINE_FUNCTION +half_t cast_to_half(half_t); + +template +KOKKOS_INLINE_FUNCTION std::enable_if_t::value, T> + cast_from_half(half_t); +template +KOKKOS_INLINE_FUNCTION std::enable_if_t::value, T> + cast_from_half(half_t); +template +KOKKOS_INLINE_FUNCTION std::enable_if_t::value, T> + cast_from_half(half_t); +template +KOKKOS_INLINE_FUNCTION std::enable_if_t::value, T> + cast_from_half(half_t); +template +KOKKOS_INLINE_FUNCTION std::enable_if_t::value, T> + cast_from_half(half_t); +template +KOKKOS_INLINE_FUNCTION std::enable_if_t::value, T> + cast_from_half(half_t); +template +KOKKOS_INLINE_FUNCTION std::enable_if_t::value, T> + cast_from_half(half_t); +template +KOKKOS_INLINE_FUNCTION + std::enable_if_t::value, T> + cast_from_half(half_t); +template +KOKKOS_INLINE_FUNCTION std::enable_if_t::value, T> + cast_from_half(half_t); +template +KOKKOS_INLINE_FUNCTION + std::enable_if_t::value, T> + cast_from_half(half_t); +template +KOKKOS_INLINE_FUNCTION + std::enable_if_t::value, T> + cast_from_half(half_t); + +// declare bhalf_t +#ifdef KOKKOS_IMPL_BHALF_TYPE_DEFINED +using bhalf_t = Kokkos::Experimental::Impl::floating_point_wrapper< + Kokkos::Impl ::bhalf_impl_t ::type>; + +KOKKOS_INLINE_FUNCTION +bhalf_t cast_to_bhalf(float val); +KOKKOS_INLINE_FUNCTION +bhalf_t cast_to_bhalf(bool val); +KOKKOS_INLINE_FUNCTION +bhalf_t cast_to_bhalf(double val); +KOKKOS_INLINE_FUNCTION +bhalf_t cast_to_bhalf(short val); +KOKKOS_INLINE_FUNCTION +bhalf_t cast_to_bhalf(int val); +KOKKOS_INLINE_FUNCTION +bhalf_t cast_to_bhalf(long val); +KOKKOS_INLINE_FUNCTION +bhalf_t cast_to_bhalf(long long val); +KOKKOS_INLINE_FUNCTION +bhalf_t cast_to_bhalf(unsigned short val); +KOKKOS_INLINE_FUNCTION +bhalf_t cast_to_bhalf(unsigned int val); +KOKKOS_INLINE_FUNCTION +bhalf_t cast_to_bhalf(unsigned long val); +KOKKOS_INLINE_FUNCTION +bhalf_t cast_to_bhalf(unsigned long long val); +KOKKOS_INLINE_FUNCTION +bhalf_t cast_to_bhalf(bhalf_t val); + +template +KOKKOS_INLINE_FUNCTION std::enable_if_t::value, T> + cast_from_bhalf(bhalf_t); +template +KOKKOS_INLINE_FUNCTION std::enable_if_t::value, T> + cast_from_bhalf(bhalf_t); +template +KOKKOS_INLINE_FUNCTION std::enable_if_t::value, T> + cast_from_bhalf(bhalf_t); +template +KOKKOS_INLINE_FUNCTION std::enable_if_t::value, T> + cast_from_bhalf(bhalf_t); +template +KOKKOS_INLINE_FUNCTION std::enable_if_t::value, T> + cast_from_bhalf(bhalf_t); +template +KOKKOS_INLINE_FUNCTION std::enable_if_t::value, T> + cast_from_bhalf(bhalf_t); +template +KOKKOS_INLINE_FUNCTION std::enable_if_t::value, T> + cast_from_bhalf(bhalf_t); +template +KOKKOS_INLINE_FUNCTION + std::enable_if_t::value, T> + cast_from_bhalf(bhalf_t); +template +KOKKOS_INLINE_FUNCTION std::enable_if_t::value, T> + cast_from_bhalf(bhalf_t); +template +KOKKOS_INLINE_FUNCTION + std::enable_if_t::value, T> + cast_from_bhalf(bhalf_t); +template +KOKKOS_INLINE_FUNCTION + std::enable_if_t::value, T> + cast_from_bhalf(bhalf_t); +#endif // KOKKOS_IMPL_BHALF_TYPE_DEFINED + +template +static KOKKOS_INLINE_FUNCTION Kokkos::Experimental::half_t cast_to_wrapper( + T x, const volatile Kokkos::Impl::half_impl_t::type&); + +#ifdef KOKKOS_IMPL_BHALF_TYPE_DEFINED +template +static KOKKOS_INLINE_FUNCTION Kokkos::Experimental::bhalf_t cast_to_wrapper( + T x, const volatile Kokkos::Impl::bhalf_impl_t::type&); +#endif // KOKKOS_IMPL_BHALF_TYPE_DEFINED + +template +static KOKKOS_INLINE_FUNCTION T +cast_from_wrapper(const Kokkos::Experimental::half_t& x); + +#ifdef KOKKOS_IMPL_BHALF_TYPE_DEFINED +template +static KOKKOS_INLINE_FUNCTION T +cast_from_wrapper(const Kokkos::Experimental::bhalf_t& x); +#endif // KOKKOS_IMPL_BHALF_TYPE_DEFINED +/************************** END forward declarations **************************/ + +namespace Impl { +template +class alignas(FloatType) floating_point_wrapper { + public: + using impl_type = FloatType; + + private: + impl_type val; + using fixed_width_integer_type = std::conditional_t< + sizeof(impl_type) == 2, uint16_t, + std::conditional_t< + sizeof(impl_type) == 4, uint32_t, + std::conditional_t>>; + static_assert(!std::is_void::value, + "Invalid impl_type"); + + public: + // In-class initialization and defaulted default constructors not used + // since Cuda supports half precision initialization via the below constructor + KOKKOS_FUNCTION + floating_point_wrapper() : val(0.0F) {} + +// Copy constructors +// Getting "C2580: multiple versions of a defaulted special +// member function are not allowed" with VS 16.11.3 and CUDA 11.4.2 +#if defined(_WIN32) && defined(KOKKOS_ENABLE_CUDA) + KOKKOS_FUNCTION + floating_point_wrapper(const floating_point_wrapper& rhs) : val(rhs.val) {} + + KOKKOS_FUNCTION + floating_point_wrapper& operator=(const floating_point_wrapper& rhs) { + val = rhs.val; + return *this; + } +#else + KOKKOS_DEFAULTED_FUNCTION + floating_point_wrapper(const floating_point_wrapper&) noexcept = default; + + KOKKOS_DEFAULTED_FUNCTION + floating_point_wrapper& operator=(const floating_point_wrapper&) noexcept = + default; +#endif + + KOKKOS_INLINE_FUNCTION + floating_point_wrapper(const volatile floating_point_wrapper& rhs) { +#if defined(KOKKOS_HALF_IS_FULL_TYPE_ON_ARCH) && !defined(KOKKOS_ENABLE_SYCL) + val = rhs.val; +#else + const volatile fixed_width_integer_type* rv_ptr = + reinterpret_cast(&rhs.val); + const fixed_width_integer_type rv_val = *rv_ptr; + val = reinterpret_cast(rv_val); +#endif // KOKKOS_HALF_IS_FULL_TYPE_ON_ARCH + } + + // Don't support implicit conversion back to impl_type. + // impl_type is a storage only type on host. + KOKKOS_FUNCTION + explicit operator impl_type() const { return val; } + KOKKOS_FUNCTION + explicit operator float() const { return cast_from_wrapper(*this); } + KOKKOS_FUNCTION + explicit operator bool() const { return cast_from_wrapper(*this); } + KOKKOS_FUNCTION + explicit operator double() const { return cast_from_wrapper(*this); } + KOKKOS_FUNCTION + explicit operator short() const { return cast_from_wrapper(*this); } + KOKKOS_FUNCTION + explicit operator int() const { return cast_from_wrapper(*this); } + KOKKOS_FUNCTION + explicit operator long() const { return cast_from_wrapper(*this); } + KOKKOS_FUNCTION + explicit operator long long() const { + return cast_from_wrapper(*this); + } + KOKKOS_FUNCTION + explicit operator unsigned short() const { + return cast_from_wrapper(*this); + } + KOKKOS_FUNCTION + explicit operator unsigned int() const { + return cast_from_wrapper(*this); + } + KOKKOS_FUNCTION + explicit operator unsigned long() const { + return cast_from_wrapper(*this); + } + KOKKOS_FUNCTION + explicit operator unsigned long long() const { + return cast_from_wrapper(*this); + } + + /** + * Conversion constructors. + * + * Support implicit conversions from impl_type, float, double -> + * floating_point_wrapper. Mixed precision expressions require upcasting which + * is done in the + * "// Binary Arithmetic" operator overloads below. + * + * Support implicit conversions from integral types -> floating_point_wrapper. + * Expressions involving floating_point_wrapper with integral types require + * downcasting the integral types to floating_point_wrapper. Existing operator + * overloads can handle this with the addition of the below implicit + * conversion constructors. + */ + KOKKOS_FUNCTION + constexpr floating_point_wrapper(impl_type rhs) : val(rhs) {} + KOKKOS_FUNCTION + floating_point_wrapper(float rhs) : val(cast_to_wrapper(rhs, val).val) {} + KOKKOS_FUNCTION + floating_point_wrapper(double rhs) : val(cast_to_wrapper(rhs, val).val) {} + KOKKOS_FUNCTION + explicit floating_point_wrapper(bool rhs) + : val(cast_to_wrapper(rhs, val).val) {} + KOKKOS_FUNCTION + floating_point_wrapper(short rhs) : val(cast_to_wrapper(rhs, val).val) {} + KOKKOS_FUNCTION + floating_point_wrapper(int rhs) : val(cast_to_wrapper(rhs, val).val) {} + KOKKOS_FUNCTION + floating_point_wrapper(long rhs) : val(cast_to_wrapper(rhs, val).val) {} + KOKKOS_FUNCTION + floating_point_wrapper(long long rhs) : val(cast_to_wrapper(rhs, val).val) {} + KOKKOS_FUNCTION + floating_point_wrapper(unsigned short rhs) + : val(cast_to_wrapper(rhs, val).val) {} + KOKKOS_FUNCTION + floating_point_wrapper(unsigned int rhs) + : val(cast_to_wrapper(rhs, val).val) {} + KOKKOS_FUNCTION + floating_point_wrapper(unsigned long rhs) + : val(cast_to_wrapper(rhs, val).val) {} + KOKKOS_FUNCTION + floating_point_wrapper(unsigned long long rhs) + : val(cast_to_wrapper(rhs, val).val) {} + + // Unary operators + KOKKOS_FUNCTION + floating_point_wrapper operator+() const { + floating_point_wrapper tmp = *this; +#ifdef KOKKOS_HALF_IS_FULL_TYPE_ON_ARCH + tmp.val = +tmp.val; +#else + tmp.val = cast_to_wrapper(+cast_from_wrapper(tmp), val).val; +#endif + return tmp; + } + + KOKKOS_FUNCTION + floating_point_wrapper operator-() const { + floating_point_wrapper tmp = *this; +#ifdef KOKKOS_HALF_IS_FULL_TYPE_ON_ARCH + tmp.val = -tmp.val; +#else + tmp.val = cast_to_wrapper(-cast_from_wrapper(tmp), val).val; +#endif + return tmp; + } + + // Prefix operators + KOKKOS_FUNCTION + floating_point_wrapper& operator++() { +#ifdef KOKKOS_HALF_IS_FULL_TYPE_ON_ARCH + val = val + impl_type(1.0F); // cuda has no operator++ for __nv_bfloat +#else + float tmp = cast_from_wrapper(*this); + ++tmp; + val = cast_to_wrapper(tmp, val).val; +#endif + return *this; + } + + KOKKOS_FUNCTION + floating_point_wrapper& operator--() { +#ifdef KOKKOS_HALF_IS_FULL_TYPE_ON_ARCH + val = val - impl_type(1.0F); // cuda has no operator-- for __nv_bfloat +#else + float tmp = cast_from_wrapper(*this); + --tmp; + val = cast_to_wrapper(tmp, val).val; +#endif + return *this; + } + + // Postfix operators + KOKKOS_FUNCTION + floating_point_wrapper operator++(int) { + floating_point_wrapper tmp = *this; + operator++(); + return tmp; + } + + KOKKOS_FUNCTION + floating_point_wrapper operator--(int) { + floating_point_wrapper tmp = *this; + operator--(); + return tmp; + } + + // Binary operators + KOKKOS_FUNCTION + floating_point_wrapper& operator=(impl_type rhs) { + val = rhs; + return *this; + } + + template + KOKKOS_FUNCTION floating_point_wrapper& operator=(T rhs) { + val = cast_to_wrapper(rhs, val).val; + return *this; + } + + template + KOKKOS_FUNCTION void operator=(T rhs) volatile { + impl_type new_val = cast_to_wrapper(rhs, val).val; + volatile fixed_width_integer_type* val_ptr = + reinterpret_cast( + const_cast(&val)); + *val_ptr = reinterpret_cast(new_val); + } + + // Compound operators + KOKKOS_FUNCTION + floating_point_wrapper& operator+=(floating_point_wrapper rhs) { +#ifdef KOKKOS_HALF_IS_FULL_TYPE_ON_ARCH + val = val + rhs.val; // cuda has no operator+= for __nv_bfloat +#else + val = cast_to_wrapper( + cast_from_wrapper(*this) + cast_from_wrapper(rhs), + val) + .val; +#endif + return *this; + } + + KOKKOS_FUNCTION + void operator+=(const volatile floating_point_wrapper& rhs) volatile { + floating_point_wrapper tmp_rhs = rhs; + floating_point_wrapper tmp_lhs = *this; + + tmp_lhs += tmp_rhs; + *this = tmp_lhs; + } + + // Compound operators: upcast overloads for += + template + KOKKOS_FUNCTION friend std::enable_if_t< + std::is_same::value || std::is_same::value, T> + operator+=(T& lhs, floating_point_wrapper rhs) { + lhs += static_cast(rhs); + return lhs; + } + + KOKKOS_FUNCTION + floating_point_wrapper& operator+=(float rhs) { + float result = static_cast(val) + rhs; + val = static_cast(result); + return *this; + } + + KOKKOS_FUNCTION + floating_point_wrapper& operator+=(double rhs) { + double result = static_cast(val) + rhs; + val = static_cast(result); + return *this; + } + + KOKKOS_FUNCTION + floating_point_wrapper& operator-=(floating_point_wrapper rhs) { +#ifdef KOKKOS_HALF_IS_FULL_TYPE_ON_ARCH + val = val - rhs.val; // cuda has no operator-= for __nv_bfloat +#else + val = cast_to_wrapper( + cast_from_wrapper(*this) - cast_from_wrapper(rhs), + val) + .val; +#endif + return *this; + } + + KOKKOS_FUNCTION + void operator-=(const volatile floating_point_wrapper& rhs) volatile { + floating_point_wrapper tmp_rhs = rhs; + floating_point_wrapper tmp_lhs = *this; + + tmp_lhs -= tmp_rhs; + *this = tmp_lhs; + } + + // Compund operators: upcast overloads for -= + template + KOKKOS_FUNCTION friend std::enable_if_t< + std::is_same::value || std::is_same::value, T> + operator-=(T& lhs, floating_point_wrapper rhs) { + lhs -= static_cast(rhs); + return lhs; + } + + KOKKOS_FUNCTION + floating_point_wrapper& operator-=(float rhs) { + float result = static_cast(val) - rhs; + val = static_cast(result); + return *this; + } + + KOKKOS_FUNCTION + floating_point_wrapper& operator-=(double rhs) { + double result = static_cast(val) - rhs; + val = static_cast(result); + return *this; + } + + KOKKOS_FUNCTION + floating_point_wrapper& operator*=(floating_point_wrapper rhs) { +#ifdef KOKKOS_HALF_IS_FULL_TYPE_ON_ARCH + val = val * rhs.val; // cuda has no operator*= for __nv_bfloat +#else + val = cast_to_wrapper( + cast_from_wrapper(*this) * cast_from_wrapper(rhs), + val) + .val; +#endif + return *this; + } + + KOKKOS_FUNCTION + void operator*=(const volatile floating_point_wrapper& rhs) volatile { + floating_point_wrapper tmp_rhs = rhs; + floating_point_wrapper tmp_lhs = *this; + + tmp_lhs *= tmp_rhs; + *this = tmp_lhs; + } + + // Compund operators: upcast overloads for *= + template + KOKKOS_FUNCTION friend std::enable_if_t< + std::is_same::value || std::is_same::value, T> + operator*=(T& lhs, floating_point_wrapper rhs) { + lhs *= static_cast(rhs); + return lhs; + } + + KOKKOS_FUNCTION + floating_point_wrapper& operator*=(float rhs) { + float result = static_cast(val) * rhs; + val = static_cast(result); + return *this; + } + + KOKKOS_FUNCTION + floating_point_wrapper& operator*=(double rhs) { + double result = static_cast(val) * rhs; + val = static_cast(result); + return *this; + } + + KOKKOS_FUNCTION + floating_point_wrapper& operator/=(floating_point_wrapper rhs) { +#ifdef KOKKOS_HALF_IS_FULL_TYPE_ON_ARCH + val = val / rhs.val; // cuda has no operator/= for __nv_bfloat +#else + val = cast_to_wrapper( + cast_from_wrapper(*this) / cast_from_wrapper(rhs), + val) + .val; +#endif + return *this; + } + + KOKKOS_FUNCTION + void operator/=(const volatile floating_point_wrapper& rhs) volatile { + floating_point_wrapper tmp_rhs = rhs; + floating_point_wrapper tmp_lhs = *this; + + tmp_lhs /= tmp_rhs; + *this = tmp_lhs; + } + + // Compund operators: upcast overloads for /= + template + KOKKOS_FUNCTION friend std::enable_if_t< + std::is_same::value || std::is_same::value, T> + operator/=(T& lhs, floating_point_wrapper rhs) { + lhs /= static_cast(rhs); + return lhs; + } + + KOKKOS_FUNCTION + floating_point_wrapper& operator/=(float rhs) { + float result = static_cast(val) / rhs; + val = static_cast(result); + return *this; + } + + KOKKOS_FUNCTION + floating_point_wrapper& operator/=(double rhs) { + double result = static_cast(val) / rhs; + val = static_cast(result); + return *this; + } + + // Binary Arithmetic + KOKKOS_FUNCTION + friend floating_point_wrapper operator+(floating_point_wrapper lhs, + floating_point_wrapper rhs) { +#ifdef KOKKOS_HALF_IS_FULL_TYPE_ON_ARCH + lhs += rhs; +#else + lhs.val = cast_to_wrapper( + cast_from_wrapper(lhs) + cast_from_wrapper(rhs), + lhs.val) + .val; +#endif + return lhs; + } + + // Binary Arithmetic upcast operators for + + template + KOKKOS_FUNCTION friend std::enable_if_t< + std::is_same::value || std::is_same::value, T> + operator+(floating_point_wrapper lhs, T rhs) { + return T(lhs) + rhs; + } + + template + KOKKOS_FUNCTION friend std::enable_if_t< + std::is_same::value || std::is_same::value, T> + operator+(T lhs, floating_point_wrapper rhs) { + return lhs + T(rhs); + } + + KOKKOS_FUNCTION + friend floating_point_wrapper operator-(floating_point_wrapper lhs, + floating_point_wrapper rhs) { +#ifdef KOKKOS_HALF_IS_FULL_TYPE_ON_ARCH + lhs -= rhs; +#else + lhs.val = cast_to_wrapper( + cast_from_wrapper(lhs) - cast_from_wrapper(rhs), + lhs.val) + .val; +#endif + return lhs; + } + + // Binary Arithmetic upcast operators for - + template + KOKKOS_FUNCTION friend std::enable_if_t< + std::is_same::value || std::is_same::value, T> + operator-(floating_point_wrapper lhs, T rhs) { + return T(lhs) - rhs; + } + + template + KOKKOS_FUNCTION friend std::enable_if_t< + std::is_same::value || std::is_same::value, T> + operator-(T lhs, floating_point_wrapper rhs) { + return lhs - T(rhs); + } + + KOKKOS_FUNCTION + friend floating_point_wrapper operator*(floating_point_wrapper lhs, + floating_point_wrapper rhs) { +#ifdef KOKKOS_HALF_IS_FULL_TYPE_ON_ARCH + lhs *= rhs; +#else + lhs.val = cast_to_wrapper( + cast_from_wrapper(lhs) * cast_from_wrapper(rhs), + lhs.val) + .val; +#endif + return lhs; + } + + // Binary Arithmetic upcast operators for * + template + KOKKOS_FUNCTION friend std::enable_if_t< + std::is_same::value || std::is_same::value, T> + operator*(floating_point_wrapper lhs, T rhs) { + return T(lhs) * rhs; + } + + template + KOKKOS_FUNCTION friend std::enable_if_t< + std::is_same::value || std::is_same::value, T> + operator*(T lhs, floating_point_wrapper rhs) { + return lhs * T(rhs); + } + + KOKKOS_FUNCTION + friend floating_point_wrapper operator/(floating_point_wrapper lhs, + floating_point_wrapper rhs) { +#ifdef KOKKOS_HALF_IS_FULL_TYPE_ON_ARCH + lhs /= rhs; +#else + lhs.val = cast_to_wrapper( + cast_from_wrapper(lhs) / cast_from_wrapper(rhs), + lhs.val) + .val; +#endif + return lhs; + } + + // Binary Arithmetic upcast operators for / + template + KOKKOS_FUNCTION friend std::enable_if_t< + std::is_same::value || std::is_same::value, T> + operator/(floating_point_wrapper lhs, T rhs) { + return T(lhs) / rhs; + } + + template + KOKKOS_FUNCTION friend std::enable_if_t< + std::is_same::value || std::is_same::value, T> + operator/(T lhs, floating_point_wrapper rhs) { + return lhs / T(rhs); + } + + // Logical operators + KOKKOS_FUNCTION + bool operator!() const { +#ifdef KOKKOS_HALF_IS_FULL_TYPE_ON_ARCH + return static_cast(!val); +#else + return !cast_from_wrapper(*this); +#endif + } + + // NOTE: Loses short-circuit evaluation + KOKKOS_FUNCTION + bool operator&&(floating_point_wrapper rhs) const { +#ifdef KOKKOS_HALF_IS_FULL_TYPE_ON_ARCH + return static_cast(val && rhs.val); +#else + return cast_from_wrapper(*this) && cast_from_wrapper(rhs); +#endif + } + + // NOTE: Loses short-circuit evaluation + KOKKOS_FUNCTION + bool operator||(floating_point_wrapper rhs) const { +#ifdef KOKKOS_HALF_IS_FULL_TYPE_ON_ARCH + return static_cast(val || rhs.val); +#else + return cast_from_wrapper(*this) || cast_from_wrapper(rhs); +#endif + } + + // Comparison operators + KOKKOS_FUNCTION + bool operator==(floating_point_wrapper rhs) const { +#ifdef KOKKOS_HALF_IS_FULL_TYPE_ON_ARCH + return static_cast(val == rhs.val); +#else + return cast_from_wrapper(*this) == cast_from_wrapper(rhs); +#endif + } + + KOKKOS_FUNCTION + bool operator!=(floating_point_wrapper rhs) const { +#ifdef KOKKOS_HALF_IS_FULL_TYPE_ON_ARCH + return static_cast(val != rhs.val); +#else + return cast_from_wrapper(*this) != cast_from_wrapper(rhs); +#endif + } + + KOKKOS_FUNCTION + bool operator<(floating_point_wrapper rhs) const { +#ifdef KOKKOS_HALF_IS_FULL_TYPE_ON_ARCH + return static_cast(val < rhs.val); +#else + return cast_from_wrapper(*this) < cast_from_wrapper(rhs); +#endif + } + + KOKKOS_FUNCTION + bool operator>(floating_point_wrapper rhs) const { +#ifdef KOKKOS_HALF_IS_FULL_TYPE_ON_ARCH + return static_cast(val > rhs.val); +#else + return cast_from_wrapper(*this) > cast_from_wrapper(rhs); +#endif + } + + KOKKOS_FUNCTION + bool operator<=(floating_point_wrapper rhs) const { +#ifdef KOKKOS_HALF_IS_FULL_TYPE_ON_ARCH + return static_cast(val <= rhs.val); +#else + return cast_from_wrapper(*this) <= cast_from_wrapper(rhs); +#endif + } + + KOKKOS_FUNCTION + bool operator>=(floating_point_wrapper rhs) const { +#ifdef KOKKOS_HALF_IS_FULL_TYPE_ON_ARCH + return static_cast(val >= rhs.val); +#else + return cast_from_wrapper(*this) >= cast_from_wrapper(rhs); +#endif + } + + KOKKOS_FUNCTION + friend bool operator==(const volatile floating_point_wrapper& lhs, + const volatile floating_point_wrapper& rhs) { + floating_point_wrapper tmp_lhs = lhs, tmp_rhs = rhs; + return tmp_lhs == tmp_rhs; + } + + KOKKOS_FUNCTION + friend bool operator!=(const volatile floating_point_wrapper& lhs, + const volatile floating_point_wrapper& rhs) { + floating_point_wrapper tmp_lhs = lhs, tmp_rhs = rhs; + return tmp_lhs != tmp_rhs; + } + + KOKKOS_FUNCTION + friend bool operator<(const volatile floating_point_wrapper& lhs, + const volatile floating_point_wrapper& rhs) { + floating_point_wrapper tmp_lhs = lhs, tmp_rhs = rhs; + return tmp_lhs < tmp_rhs; + } + + KOKKOS_FUNCTION + friend bool operator>(const volatile floating_point_wrapper& lhs, + const volatile floating_point_wrapper& rhs) { + floating_point_wrapper tmp_lhs = lhs, tmp_rhs = rhs; + return tmp_lhs > tmp_rhs; + } + + KOKKOS_FUNCTION + friend bool operator<=(const volatile floating_point_wrapper& lhs, + const volatile floating_point_wrapper& rhs) { + floating_point_wrapper tmp_lhs = lhs, tmp_rhs = rhs; + return tmp_lhs <= tmp_rhs; + } + + KOKKOS_FUNCTION + friend bool operator>=(const volatile floating_point_wrapper& lhs, + const volatile floating_point_wrapper& rhs) { + floating_point_wrapper tmp_lhs = lhs, tmp_rhs = rhs; + return tmp_lhs >= tmp_rhs; + } + + // Insertion and extraction operators + friend std::ostream& operator<<(std::ostream& os, + const floating_point_wrapper& x) { + const std::string out = std::to_string(static_cast(x)); + os << out; + return os; + } + + friend std::istream& operator>>(std::istream& is, floating_point_wrapper& x) { + std::string in; + is >> in; + x = std::stod(in); + return is; + } +}; +} // namespace Impl + +// Declare wrapper overloads now that floating_point_wrapper is declared +template +static KOKKOS_INLINE_FUNCTION Kokkos::Experimental::half_t cast_to_wrapper( + T x, const volatile Kokkos::Impl::half_impl_t::type&) { + return Kokkos::Experimental::cast_to_half(x); +} + +#ifdef KOKKOS_IMPL_BHALF_TYPE_DEFINED +template +static KOKKOS_INLINE_FUNCTION Kokkos::Experimental::bhalf_t cast_to_wrapper( + T x, const volatile Kokkos::Impl::bhalf_impl_t::type&) { + return Kokkos::Experimental::cast_to_bhalf(x); +} +#endif // KOKKOS_IMPL_BHALF_TYPE_DEFINED + +template +static KOKKOS_INLINE_FUNCTION T +cast_from_wrapper(const Kokkos::Experimental::half_t& x) { + return Kokkos::Experimental::cast_from_half(x); +} + +#ifdef KOKKOS_IMPL_BHALF_TYPE_DEFINED +template +static KOKKOS_INLINE_FUNCTION T +cast_from_wrapper(const Kokkos::Experimental::bhalf_t& x) { + return Kokkos::Experimental::cast_from_bhalf(x); +} +#endif // KOKKOS_IMPL_BHALF_TYPE_DEFINED + +} // namespace Experimental +} // namespace Kokkos + +#endif // KOKKOS_IMPL_HALF_TYPE_DEFINED + +// If none of the above actually did anything and defined a half precision type +// define a fallback implementation here using float +#ifndef KOKKOS_IMPL_HALF_TYPE_DEFINED +#define KOKKOS_IMPL_HALF_TYPE_DEFINED +#define KOKKOS_HALF_T_IS_FLOAT true +namespace Kokkos { +namespace Impl { +struct half_impl_t { + using type = float; +}; +} // namespace Impl +namespace Experimental { + +using half_t = Kokkos::Impl::half_impl_t::type; + +// cast_to_half +KOKKOS_INLINE_FUNCTION +half_t cast_to_half(float val) { return half_t(val); } +KOKKOS_INLINE_FUNCTION +half_t cast_to_half(bool val) { return half_t(val); } +KOKKOS_INLINE_FUNCTION +half_t cast_to_half(double val) { return half_t(val); } +KOKKOS_INLINE_FUNCTION +half_t cast_to_half(short val) { return half_t(val); } +KOKKOS_INLINE_FUNCTION +half_t cast_to_half(unsigned short val) { return half_t(val); } +KOKKOS_INLINE_FUNCTION +half_t cast_to_half(int val) { return half_t(val); } +KOKKOS_INLINE_FUNCTION +half_t cast_to_half(unsigned int val) { return half_t(val); } +KOKKOS_INLINE_FUNCTION +half_t cast_to_half(long val) { return half_t(val); } +KOKKOS_INLINE_FUNCTION +half_t cast_to_half(unsigned long val) { return half_t(val); } +KOKKOS_INLINE_FUNCTION +half_t cast_to_half(long long val) { return half_t(val); } +KOKKOS_INLINE_FUNCTION +half_t cast_to_half(unsigned long long val) { return half_t(val); } + +// cast_from_half +// Using an explicit list here too, since the other ones are explicit and for +// example don't include char +template +KOKKOS_INLINE_FUNCTION std::enable_if_t< + std::is_same::value || std::is_same::value || + std::is_same::value || std::is_same::value || + std::is_same::value || std::is_same::value || + std::is_same::value || std::is_same::value || + std::is_same::value || + std::is_same::value || + std::is_same::value, + T> +cast_from_half(half_t val) { + return T(val); +} + +} // namespace Experimental +} // namespace Kokkos + +#else +#define KOKKOS_HALF_T_IS_FLOAT false +#endif // KOKKOS_IMPL_HALF_TYPE_DEFINED + +#ifndef KOKKOS_IMPL_BHALF_TYPE_DEFINED +#define KOKKOS_IMPL_BHALF_TYPE_DEFINED +#define KOKKOS_BHALF_T_IS_FLOAT true +namespace Kokkos { +namespace Impl { +struct bhalf_impl_t { + using type = float; +}; +} // namespace Impl + +namespace Experimental { + +using bhalf_t = Kokkos::Impl::bhalf_impl_t::type; + +// cast_to_bhalf +KOKKOS_INLINE_FUNCTION +bhalf_t cast_to_bhalf(float val) { return bhalf_t(val); } +KOKKOS_INLINE_FUNCTION +bhalf_t cast_to_bhalf(bool val) { return bhalf_t(val); } +KOKKOS_INLINE_FUNCTION +bhalf_t cast_to_bhalf(double val) { return bhalf_t(val); } +KOKKOS_INLINE_FUNCTION +bhalf_t cast_to_bhalf(short val) { return bhalf_t(val); } +KOKKOS_INLINE_FUNCTION +bhalf_t cast_to_bhalf(unsigned short val) { return bhalf_t(val); } +KOKKOS_INLINE_FUNCTION +bhalf_t cast_to_bhalf(int val) { return bhalf_t(val); } +KOKKOS_INLINE_FUNCTION +bhalf_t cast_to_bhalf(unsigned int val) { return bhalf_t(val); } +KOKKOS_INLINE_FUNCTION +bhalf_t cast_to_bhalf(long val) { return bhalf_t(val); } +KOKKOS_INLINE_FUNCTION +bhalf_t cast_to_bhalf(unsigned long val) { return bhalf_t(val); } +KOKKOS_INLINE_FUNCTION +bhalf_t cast_to_bhalf(long long val) { return bhalf_t(val); } +KOKKOS_INLINE_FUNCTION +bhalf_t cast_to_bhalf(unsigned long long val) { return bhalf_t(val); } + +// cast_from_bhalf +template +KOKKOS_INLINE_FUNCTION std::enable_if_t< + std::is_same::value || std::is_same::value || + std::is_same::value || std::is_same::value || + std::is_same::value || std::is_same::value || + std::is_same::value || std::is_same::value || + std::is_same::value || + std::is_same::value || + std::is_same::value, + T> +cast_from_bhalf(bhalf_t val) { + return T(val); +} +} // namespace Experimental +} // namespace Kokkos +#else +#define KOKKOS_BHALF_T_IS_FLOAT false +#endif // KOKKOS_IMPL_BHALF_TYPE_DEFINED + +#endif // KOKKOS_HALF_FLOATING_POINT_WRAPPER_HPP_