Skip to content

Commit

Permalink
Add test case
Browse files Browse the repository at this point in the history
  • Loading branch information
masterleinad committed Feb 28, 2023
1 parent f93e48a commit db890c9
Show file tree
Hide file tree
Showing 2 changed files with 40 additions and 11 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,7 @@ struct NonTrivialReduceFunctor {
NonTrivialReduceFunctor(NonTrivialReduceFunctor &&) = default;
NonTrivialReduceFunctor &operator=(NonTrivialReduceFunctor &&) = default;
NonTrivialReduceFunctor &operator=(NonTrivialReduceFunctor const &) = default;
KOKKOS_FUNCTION ~NonTrivialReduceFunctor() {}
~NonTrivialReduceFunctor() {}
};

template <class ExecSpace>
Expand Down
49 changes: 39 additions & 10 deletions core/unit_test/incremental/Test16_ParallelScan.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -26,11 +26,45 @@ namespace Test {
using value_type = double;
const int N = 10;

template <typename ExecSpace>
struct TrivialScanFunctor {
Kokkos::View<value_type *, ExecSpace> d_data;

KOKKOS_FUNCTION
void operator()(const int i, value_type &update_value,
const bool final) const {
const value_type val_i = d_data(i);
if (final) d_data(i) = update_value;
update_value += val_i;
}
};

template <typename ExecSpace>
struct NonTrivialScanFunctor {
Kokkos::View<value_type *, ExecSpace> d_data;

KOKKOS_FUNCTION
void operator()(const int i, value_type &update_value,
const bool final) const {
const value_type val_i = d_data(i);
if (final) d_data(i) = update_value;
update_value += val_i;
}

NonTrivialScanFunctor() = default;
NonTrivialScanFunctor(NonTrivialScanFunctor const &) = default;
NonTrivialScanFunctor(NonTrivialScanFunctor &&) = default;
NonTrivialScanFunctor &operator=(NonTrivialScanFunctor &&) = default;
NonTrivialScanFunctor &operator=(NonTrivialScanFunctor const &) = default;
~NonTrivialScanFunctor() {}
};

template <class ExecSpace>
struct TestScan {
// 1D View of double
using View_1D = typename Kokkos::View<value_type *, ExecSpace>;

template <typename FunctorType>
void parallel_scan() {
View_1D d_data("data", N);

Expand All @@ -39,15 +73,9 @@ struct TestScan {
Kokkos::RangePolicy<ExecSpace>(0, N),
KOKKOS_LAMBDA(const int i) { d_data(i) = i * 0.5; });

// Exclusive parallel_scan call.
Kokkos::parallel_scan(
Kokkos::RangePolicy<ExecSpace>(0, N),
KOKKOS_LAMBDA(const int i, value_type &update_value, const bool final) {
const value_type val_i = d_data(i);
if (final) d_data(i) = update_value;

update_value += val_i;
});
// Exclusive parallel_scan call
Kokkos::parallel_scan(Kokkos::RangePolicy<ExecSpace>(0, N),
FunctorType{d_data});

// Copy back the data.
auto h_data =
Expand All @@ -65,7 +93,8 @@ struct TestScan {

TEST(TEST_CATEGORY, IncrTest_16_parallelscan) {
TestScan<TEST_EXECSPACE> test;
test.parallel_scan();
test.parallel_scan<TrivialScanFunctor<TEST_EXECSPACE>>();
test.parallel_scan<NonTrivialScanFunctor<TEST_EXECSPACE>>();
}

} // namespace Test

0 comments on commit db890c9

Please sign in to comment.