diff --git a/core/src/HPX/Kokkos_HPX.hpp b/core/src/HPX/Kokkos_HPX.hpp index c40392265f..06eed664df 100644 --- a/core/src/HPX/Kokkos_HPX.hpp +++ b/core/src/HPX/Kokkos_HPX.hpp @@ -1110,19 +1110,17 @@ class ParallelReduce; + using iterate_type = typename Kokkos::Impl::HostIterateTile< + MDRangePolicy, CombinedFunctorReducerType, WorkTag, reference_type>; const iterate_type m_iter; const Policy m_policy; - const CombinedFunctorReducerType m_functor_reducer; const pointer_type m_result_ptr; const bool m_force_synchronous; public: void setup() const { - const ReducerType &reducer = m_functor_reducer.get_reducer(); + const ReducerType &reducer = m_iter.m_func.get_reducer(); const std::size_t value_size = reducer.value_size(); const int num_worker_threads = m_policy.space().concurrency(); @@ -1148,7 +1146,7 @@ class ParallelReduce(buffer.get(0)), @@ -1180,9 +1178,8 @@ class ParallelReduce inline ParallelReduce(const CombinedFunctorReducerType &arg_functor_reducer, MDRangePolicy arg_policy, const ViewType &arg_view) - : m_iter(arg_policy, arg_functor_reducer.get_functor()), + : m_iter(arg_policy, arg_functor_reducer), m_policy(Policy(0, arg_policy.m_num_tiles).set_chunk_size(1)), - m_functor_reducer(arg_functor_reducer), m_result_ptr(arg_view.data()), m_force_synchronous(!arg_view.impl_track().has_record()) { static_assert( diff --git a/core/src/OpenMP/Kokkos_OpenMP_Parallel.hpp b/core/src/OpenMP/Kokkos_OpenMP_Parallel.hpp index af1ba9543b..c3dfb69f59 100644 --- a/core/src/OpenMP/Kokkos_OpenMP_Parallel.hpp +++ b/core/src/OpenMP/Kokkos_OpenMP_Parallel.hpp @@ -467,13 +467,11 @@ class ParallelReduce; + using iterate_type = typename Kokkos::Impl::HostIterateTile< + MDRangePolicy, CombinedFunctorReducerType, WorkTag, reference_type>; OpenMPInternal* m_instance; const iterate_type m_iter; - const ReducerType m_reducer; const pointer_type m_result_ptr; inline void exec_range(const Member ibeg, const Member iend, @@ -485,7 +483,8 @@ class ParallelReduceacquire_lock(); @@ -504,11 +503,11 @@ class ParallelReduceget_thread_data(0)->pool_reduce_local()); - reference_type update = m_reducer.init(ptr); + reference_type update = reducer.init(ptr); ParallelReduce::exec_range(0, m_iter.m_rp.m_num_tiles, update); - m_reducer.final(ptr); + reducer.final(ptr); m_instance->release_lock(); @@ -533,7 +532,7 @@ class ParallelReduce(data.pool_reduce_local())); std::pair range(0, 0); @@ -554,15 +553,15 @@ class ParallelReduceget_thread_data(0)->pool_reduce_local()); for (int i = 1; i < pool_size; ++i) { - m_reducer.join(ptr, - reinterpret_cast( - m_instance->get_thread_data(i)->pool_reduce_local())); + reducer.join(ptr, + reinterpret_cast( + m_instance->get_thread_data(i)->pool_reduce_local())); } - m_reducer.final(ptr); + reducer.final(ptr); if (m_result_ptr) { - const int n = m_reducer.value_count(); + const int n = reducer.value_count(); for (int j = 0; j < n; ++j) { m_result_ptr[j] = ptr[j]; @@ -578,8 +577,7 @@ class ParallelReduce; + using iterate_type = typename Kokkos::Impl::HostIterateTile< + MDRangePolicy, CombinedFunctorReducerType, WorkTag, reference_type>; const iterate_type m_iter; - const ReducerType m_reducer; const pointer_type m_result_ptr; inline void exec(reference_type update) const { @@ -98,7 +96,8 @@ class ParallelReducem_thread_team_data.pool_reduce_local()); - reference_type update = m_reducer.init(ptr); + reference_type update = reducer.init(ptr); this->exec(update); - m_reducer.final(ptr); + reducer.final(ptr); } template ParallelReduce(const CombinedFunctorReducerType& arg_functor_reducer, const MDRangePolicy& arg_policy, const ViewType& arg_result_view) - : m_iter(arg_policy, arg_functor_reducer.get_functor()), - m_reducer(arg_functor_reducer.get_reducer()), + : m_iter(arg_policy, arg_functor_reducer), m_result_ptr(arg_result_view.data()) { static_assert(Kokkos::is_view::value, "Kokkos::Serial reduce result must be a View"); diff --git a/core/src/Threads/Kokkos_Threads_Parallel_MDRange.hpp b/core/src/Threads/Kokkos_Threads_Parallel_MDRange.hpp index 3ba8d27f5c..9d06249082 100644 --- a/core/src/Threads/Kokkos_Threads_Parallel_MDRange.hpp +++ b/core/src/Threads/Kokkos_Threads_Parallel_MDRange.hpp @@ -128,12 +128,10 @@ class ParallelReduce; + using iterate_type = typename Kokkos::Impl::HostIterateTile< + MDRangePolicy, CombinedFunctorReducerType, WorkTag, reference_type>; const iterate_type m_iter; - const ReducerType m_reducer; const pointer_type m_result_ptr; inline void exec_range(const Member &ibeg, const Member &iend, @@ -156,11 +154,12 @@ class ParallelReduce(exec.reduce_memory()))); + reducer.init(static_cast(exec.reduce_memory()))); - exec.fan_in_reduce(self.m_reducer); + exec.fan_in_reduce(reducer); } template @@ -178,6 +177,7 @@ class ParallelReduce(exec.reduce_memory())); while (work_index != -1) { @@ -192,7 +192,8 @@ class ParallelReduce::value, "Kokkos::Threads reduce result must be a View"); diff --git a/core/src/impl/KokkosExp_Host_IterateTile.hpp b/core/src/impl/KokkosExp_Host_IterateTile.hpp index 82604a24c2..a44ffefa6b 100644 --- a/core/src/impl/KokkosExp_Host_IterateTile.hpp +++ b/core/src/impl/KokkosExp_Host_IterateTile.hpp @@ -2093,8 +2093,8 @@ struct HostIterateTile::apply(val, m_func, full_tile, m_offset, m_rp.m_tile, - m_tiledims); + Tag>::apply(val, m_func.get_functor(), full_tile, m_offset, + m_rp.m_tile, m_tiledims); } #else