diff --git a/core/components/prefix_sum.hpp b/core/components/prefix_sum.hpp
index 90848050590..b0cdf34018d 100644
--- a/core/components/prefix_sum.hpp
+++ b/core/components/prefix_sum.hpp
@@ -45,6 +45,22 @@ namespace gko {
 namespace kernels {
 
 
+/**
+ * \fn prefix_sum
+ * Computes an exclusive prefix sum or exclusive scan of the input array.
+ *
+ * As with the standard definition of exclusive scan, the value of the last
+ * entry of the input array does not affect the result, but it is overwritten.
+ * If the input is [3,4,1,9,100], it will be replaced by
+ * [0,3,7,8,17].
+ *
+ * \tparam IndexType  Type of entries to be scanned (summed).
+ *
+ * \param exec  Executor on which to run the scan operation
+ * \param counts  The input/output array to be scanned with the sum operation
+ * \param num_entries  Size of the array, equal to one more than the number
+ *                     of entries to be summed.
+ */
 #define GKO_DECLARE_PREFIX_SUM_KERNEL(IndexType)                    \
     void prefix_sum(std::shared_ptr<const DefaultExecutor> exec, \
                     IndexType *counts, size_type num_entries)
diff --git a/omp/components/prefix_sum.cpp b/omp/components/prefix_sum.cpp
index 8e2e964934b..6ac9d4e9291 100644
--- a/omp/components/prefix_sum.cpp
+++ b/omp/components/prefix_sum.cpp
@@ -33,21 +33,85 @@ OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
 #include "core/components/prefix_sum.hpp"
 
 
+#include <algorithm>
+
+
+#include <omp.h>
+
+
+#include "core/base/allocator.hpp"
+
+
 namespace gko {
 namespace kernels {
 namespace omp {
 namespace components {
 
 
+/*
+ * Computes the exclusive prefix sum of `counts` in place, in parallel.
+ * The last entry of the input array is never used, but is replaced.
+ *
+ * Each thread scans one contiguous chunk; the chunk totals are then scanned
+ * by a single thread and each chunk is offset by the total of its predecessors.
+ */
 template <typename IndexType>
-void prefix_sum(std::shared_ptr<const OmpExecutor> exec, IndexType *counts,
-                size_type num_entries)
+void prefix_sum(std::shared_ptr<const OmpExecutor> exec,
+                IndexType *const counts, const size_type num_entries)
 {
-    IndexType partial_sum{};
-    for (size_type i = 0; i < num_entries; ++i) {
-        auto nnz = counts[i];
-        counts[i] = partial_sum;
-        partial_sum += nnz;
+    // the operation only makes sense for arrays of size at least 2
+    if (num_entries < 2) {
+        if (num_entries == 0) {
+            return;
+        } else {
+            counts[0] = 0;
+            return;
+        }
+    }
+
+    // omp_get_max_threads() is only an upper bound on the team size, so it
+    // sizes the scratch buffer but must not be used to partition the work
+    const int max_nthreads = omp_get_max_threads();
+    vector<IndexType> proc_sums(max_nthreads, 0, {exec});
+
+#pragma omp parallel
+    {
+        // partition by the size of the team actually running this region;
+        // chunks assigned to threads that never start would stay unscanned
+        const int nthreads = omp_get_num_threads();
+        const int thread_id = omp_get_thread_num();
+        const size_type def_num_witems = (num_entries - 1) / nthreads + 1;
+        const size_type startidx = thread_id * def_num_witems;
+        const size_type endidx =
+            std::min(num_entries, (thread_id + 1) * def_num_witems);
+
+        // exclusive scan of this thread's chunk, relative to the chunk start
+        IndexType partial_sum{0};
+        for (size_type i = startidx; i < endidx; ++i) {
+            auto nnz = counts[i];
+            counts[i] = partial_sum;
+            partial_sum += nnz;
+        }
+
+        proc_sums[thread_id] = partial_sum;
+
+#pragma omp barrier
+
+        // turn the chunk totals into an inclusive scan; the implicit barrier
+        // at the end of `single` publishes the result to the whole team
+#pragma omp single
+        {
+            for (int i = 0; i < nthreads - 1; i++) {
+                proc_sums[i + 1] += proc_sums[i];
+            }
+        }
+
+        // shift every chunk but the first by the total of all earlier chunks
+        if (thread_id > 0) {
+            for (size_type i = startidx; i < endidx; i++) {
+                counts[i] += proc_sums[thread_id - 1];
+            }
+        }
     }
 }
 
diff --git a/omp/test/components/prefix_sum.cpp b/omp/test/components/prefix_sum.cpp
index 1322f8a7188..4dbd82e3e90 100644
--- a/omp/test/components/prefix_sum.cpp
+++ b/omp/test/components/prefix_sum.cpp
@@ -89,6 +89,13 @@ class PrefixSum : public ::testing::Test {
 TYPED_TEST_SUITE(PrefixSum, gko::test::IndexTypes);
 
 
+TYPED_TEST(PrefixSum, TrivialCasesEqualReference)
+{
+    this->test(0);
+    this->test(1);
+}
+
+
 TYPED_TEST(PrefixSum, SmallEqualsReference) { this->test(100); }
 
 