Skip to content

Commit

Permalink
Merge pull request #6228 from masterleinad/cherry_pick_6223
Browse files Browse the repository at this point in the history
[4.1.00] Fix SIMD support on GPUs
  • Loading branch information
dalg24 committed Jun 20, 2023
2 parents 5c3e683 + dd81ecb commit 9e84430
Show file tree
Hide file tree
Showing 7 changed files with 206 additions and 113 deletions.
21 changes: 16 additions & 5 deletions cmake/kokkos_arch.cmake
Original file line number Diff line number Diff line change
Expand Up @@ -406,8 +406,10 @@ IF (KOKKOS_ARCH_SKL)
ENDIF()

IF (KOKKOS_ARCH_SKX)
#avx512-xeon
SET(KOKKOS_ARCH_AVX512XEON ON)
# FIXME_NVHPC nvc++ doesn't seem to support AVX512.
IF (NOT KOKKOS_CXX_HOST_COMPILER_ID STREQUAL NVHPC)
SET(KOKKOS_ARCH_AVX512XEON ON)
ENDIF()
COMPILER_SPECIFIC_FLAGS(
COMPILER_ID KOKKOS_CXX_HOST_COMPILER_ID
Cray NO-VALUE-SPECIFIED
Expand All @@ -419,7 +421,10 @@ IF (KOKKOS_ARCH_SKX)
ENDIF()

IF (KOKKOS_ARCH_ICL)
SET(KOKKOS_ARCH_AVX512XEON ON)
# FIXME_NVHPC nvc++ doesn't seem to support AVX512.
IF (NOT KOKKOS_CXX_HOST_COMPILER_ID STREQUAL NVHPC)
SET(KOKKOS_ARCH_AVX512XEON ON)
ENDIF()
COMPILER_SPECIFIC_FLAGS(
COMPILER_ID KOKKOS_CXX_HOST_COMPILER_ID
MSVC /arch:AVX512
Expand All @@ -428,7 +433,10 @@ IF (KOKKOS_ARCH_ICL)
ENDIF()

IF (KOKKOS_ARCH_ICX)
SET(KOKKOS_ARCH_AVX512XEON ON)
# FIXME_NVHPC nvc++ doesn't seem to support AVX512.
IF (NOT KOKKOS_CXX_HOST_COMPILER_ID STREQUAL NVHPC)
SET(KOKKOS_ARCH_AVX512XEON ON)
ENDIF()
COMPILER_SPECIFIC_FLAGS(
COMPILER_ID KOKKOS_CXX_HOST_COMPILER_ID
MSVC /arch:AVX512
Expand All @@ -437,7 +445,10 @@ IF (KOKKOS_ARCH_ICX)
ENDIF()

IF (KOKKOS_ARCH_SPR)
SET(KOKKOS_ARCH_AVX512XEON ON)
# FIXME_NVHPC nvc++ doesn't seem to support AVX512.
IF (NOT KOKKOS_CXX_HOST_COMPILER_ID STREQUAL NVHPC)
SET(KOKKOS_ARCH_AVX512XEON ON)
ENDIF()
COMPILER_SPECIFIC_FLAGS(
COMPILER_ID KOKKOS_CXX_HOST_COMPILER_ID
MSVC /arch:AVX512
Expand Down
54 changes: 35 additions & 19 deletions simd/src/Kokkos_SIMD_AVX2.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -589,7 +589,7 @@ class simd<std::int32_t, simd_abi::avx2_fixed_size<4>> {
std::is_invocable_r_v<value_type, G,
std::integral_constant<std::size_t, 0>>,
bool> = false>
KOKKOS_FORCEINLINE_FUNCTION simd(G&& gen)
KOKKOS_IMPL_HOST_FORCEINLINE_FUNCTION simd(G&& gen)
: m_value(_mm_setr_epi32(gen(std::integral_constant<std::size_t, 0>()),
gen(std::integral_constant<std::size_t, 1>()),
gen(std::integral_constant<std::size_t, 2>()),
Expand Down Expand Up @@ -700,7 +700,7 @@ class simd<std::int64_t, simd_abi::avx2_fixed_size<4>> {
std::is_invocable_r_v<value_type, G,
std::integral_constant<std::size_t, 0>>,
bool> = false>
KOKKOS_FORCEINLINE_FUNCTION simd(G&& gen)
KOKKOS_IMPL_HOST_FORCEINLINE_FUNCTION simd(G&& gen)
: m_value(_mm256_setr_epi64x(
gen(std::integral_constant<std::size_t, 0>()),
gen(std::integral_constant<std::size_t, 1>()),
Expand Down Expand Up @@ -822,7 +822,7 @@ class simd<std::uint64_t, simd_abi::avx2_fixed_size<4>> {
std::is_invocable_r_v<value_type, G,
std::integral_constant<std::size_t, 0>>,
bool> = false>
KOKKOS_FORCEINLINE_FUNCTION simd(G&& gen)
KOKKOS_IMPL_HOST_FORCEINLINE_FUNCTION simd(G&& gen)
: m_value(_mm256_setr_epi64x(
gen(std::integral_constant<std::size_t, 0>()),
gen(std::integral_constant<std::size_t, 1>()),
Expand Down Expand Up @@ -958,11 +958,15 @@ class const_where_expression<simd_mask<double, simd_abi::avx2_fixed_size<4>>,
}
}

friend constexpr auto const& Impl::mask<double, abi_type>(
const_where_expression<mask_type, value_type> const& x);
// Read-only accessor for the stored value of this where-expression:
// returns a const reference to the m_value member (no copy is made).
// [[nodiscard]] makes the compiler warn if a caller discards the result.
[[nodiscard]] KOKKOS_IMPL_HOST_FORCEINLINE_FUNCTION value_type const&
impl_get_value() const {
return m_value;
}

friend constexpr auto const& Impl::value<double, abi_type>(
const_where_expression<mask_type, value_type> const& x);
// Read-only accessor for the stored mask of this where-expression:
// returns a const reference to the m_mask member (no copy is made).
// [[nodiscard]] makes the compiler warn if a caller discards the result.
[[nodiscard]] KOKKOS_IMPL_HOST_FORCEINLINE_FUNCTION mask_type const&
impl_get_mask() const {
return m_mask;
}
};

template <>
Expand Down Expand Up @@ -1026,11 +1030,15 @@ class const_where_expression<
static_cast<__m128i>(m_value));
}

friend constexpr auto const& Impl::mask<std::int32_t, abi_type>(
const_where_expression<mask_type, value_type> const& x);
// Read-only accessor for the stored value of this where-expression:
// returns a const reference to the m_value member (no copy is made).
// [[nodiscard]] makes the compiler warn if a caller discards the result.
[[nodiscard]] KOKKOS_IMPL_HOST_FORCEINLINE_FUNCTION value_type const&
impl_get_value() const {
return m_value;
}

friend constexpr auto const& Impl::value<std::int32_t, abi_type>(
const_where_expression<mask_type, value_type> const& x);
// Read-only accessor for the stored mask of this where-expression:
// returns a const reference to the m_mask member (no copy is made).
// [[nodiscard]] makes the compiler warn if a caller discards the result.
[[nodiscard]] KOKKOS_IMPL_HOST_FORCEINLINE_FUNCTION mask_type const&
impl_get_mask() const {
return m_mask;
}
};

template <>
Expand Down Expand Up @@ -1088,11 +1096,15 @@ class const_where_expression<
static_cast<__m256i>(m_value));
}

friend constexpr auto const& Impl::mask<std::int64_t, abi_type>(
const_where_expression<mask_type, value_type> const& x);
// Read-only accessor for the stored value of this where-expression:
// returns a const reference to the m_value member (no copy is made).
// [[nodiscard]] makes the compiler warn if a caller discards the result.
[[nodiscard]] KOKKOS_IMPL_HOST_FORCEINLINE_FUNCTION value_type const&
impl_get_value() const {
return m_value;
}

friend constexpr auto const& Impl::value<std::int64_t, abi_type>(
const_where_expression<mask_type, value_type> const& x);
// Read-only accessor for the stored mask of this where-expression:
// returns a const reference to the m_mask member (no copy is made).
// [[nodiscard]] makes the compiler warn if a caller discards the result.
[[nodiscard]] KOKKOS_IMPL_HOST_FORCEINLINE_FUNCTION mask_type const&
impl_get_mask() const {
return m_mask;
}
};

template <>
Expand Down Expand Up @@ -1152,11 +1164,15 @@ class const_where_expression<
static_cast<__m256i>(m_value));
}

friend constexpr auto const& Impl::mask<std::uint64_t, abi_type>(
const_where_expression<mask_type, value_type> const& x);
// Read-only accessor for the stored value of this where-expression:
// returns a const reference to the m_value member (no copy is made).
// [[nodiscard]] makes the compiler warn if a caller discards the result.
[[nodiscard]] KOKKOS_IMPL_HOST_FORCEINLINE_FUNCTION value_type const&
impl_get_value() const {
return m_value;
}

friend constexpr auto const& Impl::value<std::uint64_t, abi_type>(
const_where_expression<mask_type, value_type> const& x);
// Read-only accessor for the stored mask of this where-expression:
// returns a const reference to the m_mask member (no copy is made).
// [[nodiscard]] makes the compiler warn if a caller discards the result.
[[nodiscard]] KOKKOS_IMPL_HOST_FORCEINLINE_FUNCTION mask_type const&
impl_get_mask() const {
return m_mask;
}
};

template <>
Expand Down
78 changes: 49 additions & 29 deletions simd/src/Kokkos_SIMD_AVX512.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -145,7 +145,7 @@ class simd<std::int32_t, simd_abi::avx512_fixed_size<8>> {
std::is_invocable_r_v<value_type, G,
std::integral_constant<std::size_t, 0>>,
bool> = false>
KOKKOS_FORCEINLINE_FUNCTION simd(G&& gen)
KOKKOS_IMPL_HOST_FORCEINLINE_FUNCTION simd(G&& gen)
: m_value(
_mm256_setr_epi32(gen(std::integral_constant<std::size_t, 0>()),
gen(std::integral_constant<std::size_t, 1>()),
Expand Down Expand Up @@ -854,11 +854,15 @@ class const_where_expression<simd_mask<double, simd_abi::avx512_fixed_size<8>>,
static_cast<__m512d>(m_value), 8);
}

friend constexpr auto const& Impl::mask<double, abi_type>(
const_where_expression<mask_type, value_type> const& x);
// Read-only accessor for the stored value of this where-expression:
// returns a const reference to the m_value member (no copy is made).
// [[nodiscard]] makes the compiler warn if a caller discards the result.
[[nodiscard]] KOKKOS_IMPL_HOST_FORCEINLINE_FUNCTION value_type const&
impl_get_value() const {
return m_value;
}

friend constexpr auto const& Impl::value<double, abi_type>(
const_where_expression<mask_type, value_type> const& x);
// Read-only accessor for the stored mask of this where-expression:
// returns a const reference to the m_mask member (no copy is made).
// [[nodiscard]] makes the compiler warn if a caller discards the result.
[[nodiscard]] KOKKOS_IMPL_HOST_FORCEINLINE_FUNCTION mask_type const&
impl_get_mask() const {
return m_mask;
}
};

template <>
Expand Down Expand Up @@ -922,11 +926,15 @@ class const_where_expression<
static_cast<__m256i>(m_value));
}

friend constexpr auto const& Impl::mask<std::int32_t, abi_type>(
const_where_expression<mask_type, value_type> const& x);
// Read-only accessor for the stored value of this where-expression:
// returns a const reference to the m_value member (no copy is made).
// [[nodiscard]] makes the compiler warn if a caller discards the result.
[[nodiscard]] KOKKOS_IMPL_HOST_FORCEINLINE_FUNCTION value_type const&
impl_get_value() const {
return m_value;
}

friend constexpr auto const& Impl::value<std::int32_t, abi_type>(
const_where_expression<mask_type, value_type> const& x);
// Read-only accessor for the stored mask of this where-expression:
// returns a const reference to the m_mask member (no copy is made).
// [[nodiscard]] makes the compiler warn if a caller discards the result.
[[nodiscard]] KOKKOS_IMPL_HOST_FORCEINLINE_FUNCTION mask_type const&
impl_get_mask() const {
return m_mask;
}
};

template <>
Expand Down Expand Up @@ -984,11 +992,15 @@ class const_where_expression<
static_cast<__m256i>(m_value));
}

friend constexpr auto const& Impl::mask<std::uint32_t, abi_type>(
const_where_expression<mask_type, value_type> const& x);
// Read-only accessor for the stored value of this where-expression:
// returns a const reference to the m_value member (no copy is made).
// [[nodiscard]] makes the compiler warn if a caller discards the result.
[[nodiscard]] KOKKOS_IMPL_HOST_FORCEINLINE_FUNCTION value_type const&
impl_get_value() const {
return m_value;
}

friend constexpr auto const& Impl::value<std::uint32_t, abi_type>(
const_where_expression<mask_type, value_type> const& x);
// Read-only accessor for the stored mask of this where-expression:
// returns a const reference to the m_mask member (no copy is made).
// [[nodiscard]] makes the compiler warn if a caller discards the result.
[[nodiscard]] KOKKOS_IMPL_HOST_FORCEINLINE_FUNCTION mask_type const&
impl_get_mask() const {
return m_mask;
}
};

template <>
Expand Down Expand Up @@ -1046,11 +1058,15 @@ class const_where_expression<
static_cast<__m512i>(m_value));
}

friend constexpr auto const& Impl::mask<std::int64_t, abi_type>(
const_where_expression<mask_type, value_type> const& x);
// Read-only accessor for the stored value of this where-expression:
// returns a const reference to the m_value member (no copy is made).
// [[nodiscard]] makes the compiler warn if a caller discards the result.
[[nodiscard]] KOKKOS_IMPL_HOST_FORCEINLINE_FUNCTION value_type const&
impl_get_value() const {
return m_value;
}

friend constexpr auto const& Impl::value<std::int64_t, abi_type>(
const_where_expression<mask_type, value_type> const& x);
// Read-only accessor for the stored mask of this where-expression:
// returns a const reference to the m_mask member (no copy is made).
// [[nodiscard]] makes the compiler warn if a caller discards the result.
[[nodiscard]] KOKKOS_IMPL_HOST_FORCEINLINE_FUNCTION mask_type const&
impl_get_mask() const {
return m_mask;
}
};

template <>
Expand Down Expand Up @@ -1108,11 +1124,15 @@ class const_where_expression<
static_cast<__m512i>(m_value));
}

friend constexpr auto const& Impl::mask<std::uint64_t, abi_type>(
const_where_expression<mask_type, value_type> const& x);
// Read-only accessor for the stored value of this where-expression:
// returns a const reference to the m_value member (no copy is made).
// [[nodiscard]] makes the compiler warn if a caller discards the result.
[[nodiscard]] KOKKOS_IMPL_HOST_FORCEINLINE_FUNCTION value_type const&
impl_get_value() const {
return m_value;
}

friend constexpr auto const& Impl::value<std::uint64_t, abi_type>(
const_where_expression<mask_type, value_type> const& x);
// Read-only accessor for the stored mask of this where-expression:
// returns a const reference to the m_mask member (no copy is made).
// [[nodiscard]] makes the compiler warn if a caller discards the result.
[[nodiscard]] KOKKOS_IMPL_HOST_FORCEINLINE_FUNCTION mask_type const&
impl_get_mask() const {
return m_mask;
}
};

template <>
Expand Down Expand Up @@ -1152,34 +1172,34 @@ class where_expression<simd_mask<std::uint64_t, simd_abi::avx512_fixed_size<8>>,
simd_mask<std::int32_t, simd_abi::avx512_fixed_size<8>>,
simd<std::int32_t, simd_abi::avx512_fixed_size<8>>> const& x) {
return _mm512_mask_reduce_max_epi32(
static_cast<__mmask8>(Impl::mask(x)),
_mm512_castsi256_si512(static_cast<__m256i>(Impl::value(x))));
static_cast<__mmask8>(x.impl_get_mask()),
_mm512_castsi256_si512(static_cast<__m256i>(x.impl_get_value())));
}

[[nodiscard]] KOKKOS_IMPL_HOST_FORCEINLINE_FUNCTION double hmin(
const_where_expression<simd_mask<double, simd_abi::avx512_fixed_size<8>>,
simd<double, simd_abi::avx512_fixed_size<8>>> const&
x) {
return _mm512_mask_reduce_min_pd(static_cast<__mmask8>(Impl::mask(x)),
static_cast<__m512d>(Impl::value(x)));
return _mm512_mask_reduce_min_pd(static_cast<__mmask8>(x.impl_get_mask()),
static_cast<__m512d>(x.impl_get_value()));
}

[[nodiscard]] KOKKOS_IMPL_HOST_FORCEINLINE_FUNCTION std::int64_t reduce(
const_where_expression<
simd_mask<std::int64_t, simd_abi::avx512_fixed_size<8>>,
simd<std::int64_t, simd_abi::avx512_fixed_size<8>>> const& x,
std::int64_t, std::plus<>) {
return _mm512_mask_reduce_add_epi64(static_cast<__mmask8>(Impl::mask(x)),
static_cast<__m512i>(Impl::value(x)));
return _mm512_mask_reduce_add_epi64(static_cast<__mmask8>(x.impl_get_mask()),
static_cast<__m512i>(x.impl_get_value()));
}

[[nodiscard]] KOKKOS_IMPL_HOST_FORCEINLINE_FUNCTION double reduce(
const_where_expression<simd_mask<double, simd_abi::avx512_fixed_size<8>>,
simd<double, simd_abi::avx512_fixed_size<8>>> const&
x,
double, std::plus<>) {
return _mm512_mask_reduce_add_pd(static_cast<__mmask8>(Impl::mask(x)),
static_cast<__m512d>(Impl::value(x)));
return _mm512_mask_reduce_add_pd(static_cast<__mmask8>(x.impl_get_mask()),
static_cast<__m512d>(x.impl_get_value()));
}

} // namespace Experimental
Expand Down
38 changes: 11 additions & 27 deletions simd/src/Kokkos_SIMD_Common.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -92,14 +92,14 @@ class where_expression<bool, T> : public const_where_expression<bool, T> {
};

template <class T, class Abi>
[[nodiscard]] KOKKOS_FORCEINLINE_FUNCTION
[[nodiscard]] KOKKOS_IMPL_HOST_FORCEINLINE_FUNCTION
where_expression<simd_mask<T, Abi>, simd<T, Abi>>
where(typename simd<T, Abi>::mask_type const& mask, simd<T, Abi>& value) {
return where_expression(mask, value);
}

template <class T, class Abi>
[[nodiscard]] KOKKOS_FORCEINLINE_FUNCTION
[[nodiscard]] KOKKOS_IMPL_HOST_FORCEINLINE_FUNCTION
const_where_expression<simd_mask<T, Abi>, simd<T, Abi>>
where(typename simd<T, Abi>::mask_type const& mask,
simd<T, Abi> const& value) {
Expand Down Expand Up @@ -308,44 +308,28 @@ KOKKOS_FORCEINLINE_FUNCTION where_expression<M, T>& operator/=(
// fallback implementations of reductions across simd_mask:

template <class T, class Abi>
[[nodiscard]] KOKKOS_FORCEINLINE_FUNCTION bool all_of(
[[nodiscard]] KOKKOS_IMPL_HOST_FORCEINLINE_FUNCTION bool all_of(
simd_mask<T, Abi> const& a) {
return a == simd_mask<T, Abi>(true);
}

template <class T, class Abi>
[[nodiscard]] KOKKOS_FORCEINLINE_FUNCTION bool any_of(
[[nodiscard]] KOKKOS_IMPL_HOST_FORCEINLINE_FUNCTION bool any_of(
simd_mask<T, Abi> const& a) {
return a != simd_mask<T, Abi>(false);
}

template <class T, class Abi>
[[nodiscard]] KOKKOS_FORCEINLINE_FUNCTION bool none_of(
[[nodiscard]] KOKKOS_IMPL_HOST_FORCEINLINE_FUNCTION bool none_of(
simd_mask<T, Abi> const& a) {
return a == simd_mask<T, Abi>(false);
}

namespace Impl {

template <typename T, typename Abi>
[[nodiscard]] KOKKOS_FORCEINLINE_FUNCTION constexpr auto const& mask(
const_where_expression<simd_mask<T, Abi>, simd<T, Abi>> const& x) {
return x.m_mask;
}

template <typename T, typename Abi>
[[nodiscard]] KOKKOS_FORCEINLINE_FUNCTION constexpr auto const& value(
const_where_expression<simd_mask<T, Abi>, simd<T, Abi>> const& x) {
return x.m_value;
}

} // namespace Impl

template <typename T, typename Abi>
[[nodiscard]] KOKKOS_IMPL_HOST_FORCEINLINE_FUNCTION T
hmin(const_where_expression<simd_mask<T, Abi>, simd<T, Abi>> const& x) {
auto const& v = Impl::value(x);
auto const& m = Impl::mask(x);
auto const& v = x.impl_get_value();
auto const& m = x.impl_get_mask();
auto result = Kokkos::reduction_identity<T>::min();
for (std::size_t i = 0; i < v.size(); ++i) {
if (m[i]) result = Kokkos::min(result, v[i]);
Expand All @@ -356,8 +340,8 @@ hmin(const_where_expression<simd_mask<T, Abi>, simd<T, Abi>> const& x) {
template <class T, class Abi>
[[nodiscard]] KOKKOS_IMPL_HOST_FORCEINLINE_FUNCTION T
hmax(const_where_expression<simd_mask<T, Abi>, simd<T, Abi>> const& x) {
auto const& v = Impl::value(x);
auto const& m = Impl::mask(x);
auto const& v = x.impl_get_value();
auto const& m = x.impl_get_mask();
auto result = Kokkos::reduction_identity<T>::max();
for (std::size_t i = 0; i < v.size(); ++i) {
if (m[i]) result = Kokkos::max(result, v[i]);
Expand All @@ -369,8 +353,8 @@ template <class T, class Abi>
[[nodiscard]] KOKKOS_IMPL_HOST_FORCEINLINE_FUNCTION T
reduce(const_where_expression<simd_mask<T, Abi>, simd<T, Abi>> const& x, T,
std::plus<>) {
auto const& v = Impl::value(x);
auto const& m = Impl::mask(x);
auto const& v = x.impl_get_value();
auto const& m = x.impl_get_mask();
auto result = Kokkos::reduction_identity<T>::sum();
for (std::size_t i = 0; i < v.size(); ++i) {
if (m[i]) result += v[i];
Expand Down
Loading

0 comments on commit 9e84430

Please sign in to comment.