Skip to content

Commit

Permalink
Implement rot{l,r} function templates
Browse files Browse the repository at this point in the history
  • Loading branch information
dalg24 committed Feb 23, 2023
1 parent fb3d754 commit 22ee14e
Show file tree
Hide file tree
Showing 3 changed files with 342 additions and 0 deletions.
70 changes: 70 additions & 0 deletions core/src/Kokkos_BitManipulation.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -153,6 +153,32 @@ bit_width(T x) noexcept {
}
//</editor-fold>

//<editor-fold desc="[bit.rotate], rotating">
template <class T>
[[nodiscard]] KOKKOS_FUNCTION constexpr std::enable_if_t<
Impl::is_standard_unsigned_integer_type_v<T>, T>
rotl(T x, int s) noexcept {
using Experimental::digits_v;
constexpr auto dig = digits_v<T>;
int const rem = s % dig;
if (rem == 0) return x;
if (rem > 0) return (x << rem) | (x >> ((dig - rem) % dig));
return (x >> -rem) | (x << ((dig + rem) % dig)); // rotr(x, -rem)
}

template <class T>
[[nodiscard]] KOKKOS_FUNCTION constexpr std::enable_if_t<
Impl::is_standard_unsigned_integer_type_v<T>, T>
rotr(T x, int s) noexcept {
using Experimental::digits_v;
constexpr auto dig = digits_v<T>;
int const rem = s % dig;
if (rem == 0) return x;
if (rem > 0) return (x >> rem) | (x << ((dig - rem) % dig));
return (x << -rem) | (x >> ((dig + rem) % dig)); // rotl(x, -rem)
}
//</editor-fold>

} // namespace Kokkos

namespace Kokkos::Impl {
Expand Down Expand Up @@ -278,6 +304,34 @@ KOKKOS_IMPL_HOST_FUNCTION

#undef KOKKOS_IMPL_USE_GCC_BUILT_IN_FUNCTIONS

template <class T>
KOKKOS_FUNCTION T rotl_builtin_host(T x, int s) noexcept {
return rotl(x, s);
}

template <class T>
KOKKOS_FUNCTION T rotl_builtin_device(T x, int s) noexcept {
#ifdef KOKKOS_ENABLE_SYCL
return sycl::rotate(x, s);
#else
return rotl(x, s);
#endif
}

template <class T>
KOKKOS_FUNCTION T rotr_builtin_host(T x, int s) noexcept {
return rotr(x, s);
}

template <class T>
KOKKOS_FUNCTION T rotr_builtin_device(T x, int s) noexcept {
#ifdef KOKKOS_ENABLE_SYCL
return sycl::rotate(x, -s);
#else
return rotr(x, s);
#endif
}

} // namespace Kokkos::Impl

namespace Kokkos::Experimental {
Expand Down Expand Up @@ -353,6 +407,22 @@ KOKKOS_FUNCTION
return digits_v<T> - countl_zero_builtin(x);
}

template <class T>
[[nodiscard]] KOKKOS_FUNCTION
std::enable_if_t<::Kokkos::Impl::is_standard_unsigned_integer_type_v<T>, T>
rotl_builtin(T x, int s) noexcept {
KOKKOS_IF_ON_DEVICE((return ::Kokkos::Impl::rotl_builtin_device(x, s);))
KOKKOS_IF_ON_HOST((return ::Kokkos::Impl::rotl_builtin_host(x, s);))
}

template <class T>
[[nodiscard]] KOKKOS_FUNCTION
std::enable_if_t<::Kokkos::Impl::is_standard_unsigned_integer_type_v<T>, T>
rotr_builtin(T x, int s) noexcept {
KOKKOS_IF_ON_DEVICE((return ::Kokkos::Impl::rotr_builtin_device(x, s);))
KOKKOS_IF_ON_HOST((return ::Kokkos::Impl::rotr_builtin_host(x, s);))
}

} // namespace Kokkos::Experimental

#endif
99 changes: 99 additions & 0 deletions core/unit_test/TestBitManipulation.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,105 @@ struct X {
static_assert(test_##FUNC((float)0).did_not_match()); \
static_assert(test_##FUNC((void*)0).did_not_match())

//<editor-fold desc="[bit.rotate]">
template <class UInt>
constexpr auto test_rotl(UInt x) -> decltype(Kokkos::rotl(x, 0)) {
using Kokkos::rotl;

static_assert(noexcept(rotl(x, 0)));
static_assert(std::is_same_v<decltype(rotl(x, 0)), UInt>);

constexpr auto dig = Kokkos::Experimental::digits_v<UInt>;
constexpr auto max = Kokkos::Experimental::finite_max_v<UInt>;

static_assert(rotl(UInt(0), 0) == 0);
static_assert(rotl(UInt(0), 1) == 0);
static_assert(rotl(UInt(0), 4) == 0);
static_assert(rotl(UInt(0), 8) == 0);
static_assert(rotl(max, 0) == max);
static_assert(rotl(max, 1) == max);
static_assert(rotl(max, 4) == max);
static_assert(rotl(max, 8) == max);
static_assert(rotl(UInt(1), 0) == UInt(1) << 0);
static_assert(rotl(UInt(1), 1) == UInt(1) << 1);
static_assert(rotl(UInt(1), 4) == UInt(1) << 4);
static_assert(rotl(UInt(1), dig) == UInt(1));
static_assert(rotl(UInt(7), dig) == UInt(7));
static_assert(rotl(UInt(6), dig - 1) == UInt(3));
static_assert(rotl(UInt(3), 6) == UInt(3) << 6);

static_assert(rotl(UInt(max - 1), 0) == UInt(max - 1));
static_assert(rotl(UInt(max - 1), 1) == UInt(max - 2));
static_assert(rotl(UInt(max - 1), 2) == UInt(max - 4));
static_assert(rotl(UInt(max - 1), 3) == UInt(max - 8));
static_assert(rotl(UInt(max - 1), 4) == UInt(max - 16));
static_assert(rotl(UInt(max - 1), 5) == UInt(max - 32));
static_assert(rotl(UInt(max - 1), 6) == UInt(max - 64));
static_assert(rotl(UInt(max - 1), 7) == UInt(max - 128));
static_assert(rotl(UInt(1), 0) == UInt(1));
static_assert(rotl(UInt(1), 1) == UInt(2));
static_assert(rotl(UInt(1), 2) == UInt(4));
static_assert(rotl(UInt(1), 3) == UInt(8));
static_assert(rotl(UInt(1), 4) == UInt(16));
static_assert(rotl(UInt(1), 5) == UInt(32));
static_assert(rotl(UInt(1), 6) == UInt(64));
static_assert(rotl(UInt(1), 7) == UInt(128));

return true;
}

TEST_BIT_MANIPULATION(rotl);

template <class UInt>
constexpr auto test_rotr(UInt x) -> decltype(Kokkos::rotr(x, 0)) {
using Kokkos::rotr;

static_assert(noexcept(rotr(x, 0)));
static_assert(std::is_same_v<decltype(rotr(x, 0)), UInt>);

constexpr auto dig = Kokkos::Experimental::digits_v<UInt>;
constexpr auto max = Kokkos::Experimental::finite_max_v<UInt>;
constexpr auto highbit = rotr(UInt(1), 1);

static_assert(rotr(UInt(0), 0) == 0);
static_assert(rotr(UInt(0), 1) == 0);
static_assert(rotr(UInt(0), 4) == 0);
static_assert(rotr(UInt(0), 8) == 0);
static_assert(rotr(max, 0) == max);
static_assert(rotr(max, 1) == max);
static_assert(rotr(max, 4) == max);
static_assert(rotr(max, 8) == max);
static_assert(rotr(UInt(128), 0) == UInt(128) >> 0);
static_assert(rotr(UInt(128), 1) == UInt(128) >> 1);
static_assert(rotr(UInt(128), 4) == UInt(128) >> 4);
static_assert(rotr(UInt(1), dig) == UInt(1));
static_assert(rotr(UInt(7), dig) == UInt(7));
static_assert(rotr(UInt(6), dig - 1) == UInt(12));
static_assert(rotr(UInt(36), dig - 2) == UInt(144));

static_assert(rotr(UInt(max - 1), 0) == UInt(max - 1));
static_assert(rotr(UInt(max - 1), 1) == UInt(max - highbit));
static_assert(rotr(UInt(max - 1), 2) == UInt(max - (highbit >> 1)));
static_assert(rotr(UInt(max - 1), 3) == UInt(max - (highbit >> 2)));
static_assert(rotr(UInt(max - 1), 4) == UInt(max - (highbit >> 3)));
static_assert(rotr(UInt(max - 1), 5) == UInt(max - (highbit >> 4)));
static_assert(rotr(UInt(max - 1), 6) == UInt(max - (highbit >> 5)));
static_assert(rotr(UInt(max - 1), 7) == UInt(max - (highbit >> 6)));
static_assert(rotr(UInt(128), 0) == UInt(128));
static_assert(rotr(UInt(128), 1) == UInt(64));
static_assert(rotr(UInt(128), 2) == UInt(32));
static_assert(rotr(UInt(128), 3) == UInt(16));
static_assert(rotr(UInt(128), 4) == UInt(8));
static_assert(rotr(UInt(128), 5) == UInt(4));
static_assert(rotr(UInt(128), 6) == UInt(2));
static_assert(rotr(UInt(128), 7) == UInt(1));

return true;
}

TEST_BIT_MANIPULATION(rotr);
//</editor-fold>

//<editor-fold desc="[bit.count]">
template <class UInt>
constexpr auto test_countl_zero(UInt x) -> decltype(Kokkos::countl_zero(x)) {
Expand Down
173 changes: 173 additions & 0 deletions core/unit_test/TestBitManipulationBuiltins.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -54,6 +54,8 @@ DEFINE_BIT_MANIPULATION_FUNCTION_EVAL(bit_ceil);
DEFINE_BIT_MANIPULATION_FUNCTION_EVAL(bit_floor);
DEFINE_BIT_MANIPULATION_FUNCTION_EVAL(bit_width);

#undef DEFINE_BIT_MANIPULATION_FUNCTION_EVAL

template <class Space, class Func, class Arg, std::size_t N>
struct TestBitManipFunction {
Arg val_[N];
Expand Down Expand Up @@ -425,3 +427,174 @@ TEST(TEST_CATEGORY, bit_manip_bit_width) {
test_bit_manip_bit_width<unsigned long>();
test_bit_manip_bit_width<unsigned long long>();
}

#undef TEST_BIT_MANIP_FUNCTION

#define DEFINE_BIT_ROTATE_FUNCTION_EVAL(FUNC) \
struct BitRotateFunction_##FUNC { \
template <class T> \
static KOKKOS_FUNCTION auto eval_constexpr(T x, int s) { \
return Kokkos::FUNC(x, s); \
} \
template <class T> \
static KOKKOS_FUNCTION auto eval_builtin(T x, int s) { \
return Kokkos::Experimental::FUNC##_builtin(x, s); \
} \
static char const* name() { return #FUNC; } \
}

DEFINE_BIT_ROTATE_FUNCTION_EVAL(rotl);
DEFINE_BIT_ROTATE_FUNCTION_EVAL(rotr);

#undef DEFINE_BIT_ROTATE_FUNCTION_EVAL

template <class T>
struct P {
using type = T;
T x;
int s;
};

template <class Space, class Func, class Arg, std::size_t N>
struct TestBitRotateFunction {
Arg val_[N];
TestBitRotateFunction(const Arg (&val)[N]) {
std::copy(val, val + N, val_);
run();
}
void run() const {
int errors = 0;
Kokkos::parallel_reduce(Kokkos::RangePolicy<Space>(0, N), *this, errors);
ASSERT_EQ(errors, 0) << "Failed check no error for " << Func::name() << "("
<< type_helper<typename Arg::type>::name() << ", int)";
}
KOKKOS_FUNCTION void operator()(int i, int& e) const {
if (Func::eval_builtin(val_[i].x, val_[i].s) !=
Func::eval_constexpr(val_[i].x, val_[i].s)) {
++e;
KOKKOS_IMPL_DO_NOT_USE_PRINTF(
"value at %x rotated by %d which is %x was expected to be %x\n",
(unsigned)val_[i].x, val_[i].s,
(unsigned)Func::eval_builtin(val_[i].x, val_[i].s),
(unsigned)Func::eval_constexpr(val_[i].x, val_[i].s));
}
}
};

template <class Space, class... Func, class Arg, std::size_t N>
void do_test_bit_rotate_function(const Arg (&x)[N]) {
(void)std::initializer_list<int>{
(TestBitRotateFunction<Space, Func, Arg, N>(x), 0)...};
}

#define TEST_BIT_ROTATE_FUNCTION(FUNC) \
do_test_bit_rotate_function<TEST_EXECSPACE, BitRotateFunction_##FUNC>

template <class UInt>
void test_bit_manip_rotl() {
using Kokkos::Experimental::rotl_builtin;
static_assert(noexcept(rotl_builtin(UInt(), 0)));
static_assert(std::is_same_v<decltype(rotl_builtin(UInt(), 0)), UInt>);
constexpr auto dig = Kokkos::Experimental::digits_v<UInt>;
constexpr auto max = Kokkos::Experimental::finite_max_v<UInt>;
TEST_BIT_ROTATE_FUNCTION(rotl)
({
// clang-format off
P<UInt>{UInt(0), 0},
P<UInt>{UInt(0), 1},
P<UInt>{UInt(0), 4},
P<UInt>{UInt(0), 8},
P<UInt>{max, 0},
P<UInt>{max, 1},
P<UInt>{max, 4},
P<UInt>{max, 8},
P<UInt>{UInt(1), 0},
P<UInt>{UInt(1), 1},
P<UInt>{UInt(1), 4},
P<UInt>{UInt(1), dig},
P<UInt>{UInt(7), dig},
P<UInt>{UInt(6), dig - 1},
P<UInt>{UInt(3), 6},
P<UInt>{UInt(max - 1), 0},
P<UInt>{UInt(max - 1), 1},
P<UInt>{UInt(max - 1), 2},
P<UInt>{UInt(max - 1), 3},
P<UInt>{UInt(max - 1), 4},
P<UInt>{UInt(max - 1), 5},
P<UInt>{UInt(max - 1), 6},
P<UInt>{UInt(max - 1), 7},
P<UInt>{UInt(1), 0},
P<UInt>{UInt(1), 1},
P<UInt>{UInt(1), 2},
P<UInt>{UInt(1), 3},
P<UInt>{UInt(1), 4},
P<UInt>{UInt(1), 5},
P<UInt>{UInt(1), 6},
P<UInt>{UInt(1), 7},
// clang-format on
});
}

TEST(TEST_CATEGORY, bit_manip_rotl) {
test_bit_manip_rotl<unsigned char>();
test_bit_manip_rotl<unsigned short>();
test_bit_manip_rotl<unsigned int>();
test_bit_manip_rotl<unsigned long>();
test_bit_manip_rotl<unsigned long long>();
}

template <class UInt>
void test_bit_manip_rotr() {
using Kokkos::rotr;
using Kokkos::Experimental::rotr_builtin;
static_assert(noexcept(rotr_builtin(UInt(), 0)));
static_assert(std::is_same_v<decltype(rotr_builtin(UInt(), 0)), UInt>);
constexpr auto dig = Kokkos::Experimental::digits_v<UInt>;
constexpr auto max = Kokkos::Experimental::finite_max_v<UInt>;
TEST_BIT_ROTATE_FUNCTION(rotr)
({
// clang-format off
P<UInt>{UInt(0), 0},
P<UInt>{UInt(0), 1},
P<UInt>{UInt(0), 4},
P<UInt>{UInt(0), 8},
P<UInt>{max, 0},
P<UInt>{max, 1},
P<UInt>{max, 4},
P<UInt>{max, 8},
P<UInt>{UInt(128), 0},
P<UInt>{UInt(128), 1},
P<UInt>{UInt(128), 4},
P<UInt>{UInt(1), dig},
P<UInt>{UInt(7), dig},
P<UInt>{UInt(6), dig - 1},
P<UInt>{UInt(36), dig - 2},
P<UInt>{UInt(max - 1), 0},
P<UInt>{UInt(max - 1), 1},
P<UInt>{UInt(max - 1), 2},
P<UInt>{UInt(max - 1), 3},
P<UInt>{UInt(max - 1), 4},
P<UInt>{UInt(max - 1), 5},
P<UInt>{UInt(max - 1), 6},
P<UInt>{UInt(max - 1), 7},
P<UInt>{UInt(128), 0},
P<UInt>{UInt(128), 1},
P<UInt>{UInt(128), 2},
P<UInt>{UInt(128), 3},
P<UInt>{UInt(128), 4},
P<UInt>{UInt(128), 5},
P<UInt>{UInt(128), 6},
P<UInt>{UInt(128), 0},
// clang-format on
});
}

TEST(TEST_CATEGORY, bit_manip_rotr) {
test_bit_manip_rotr<unsigned char>();
test_bit_manip_rotr<unsigned short>();
test_bit_manip_rotr<unsigned int>();
test_bit_manip_rotr<unsigned long>();
test_bit_manip_rotr<unsigned long long>();
}

#undef TEST_BIT_ROTATE_FUNCTION

0 comments on commit 22ee14e

Please sign in to comment.