diff --git a/common/factorization/par_ilut_kernels.hpp.inc b/common/factorization/par_ilut_kernels.hpp.inc index b5b7809f2e1..321a0805175 100644 --- a/common/factorization/par_ilut_kernels.hpp.inc +++ b/common/factorization/par_ilut_kernels.hpp.inc @@ -247,13 +247,13 @@ template __device__ void abstract_threshold_filter( const IndexType *row_ptrs, const ValueType *vals, IndexType num_rows, - remove_complex threshold, bool is_lower, BeginCallback begin_cb, + remove_complex threshold, BeginCallback begin_cb, StepCallback step_cb, FinishCallback finish_cb) { auto tidx = threadIdx.x + blockDim.x * blockIdx.x; - auto row = tidx / config::warp_size; - auto lane = threadIdx.x % config::warp_size; auto warp = group::thread_block_tile(); + auto row = tidx / warp.size(); + auto lane = warp.thread_rank(); using lanemask = decltype(warp.ballot(true)); auto lane_prefix_mask = (lanemask(1) << lane) - 1; if (row >= num_rows) { @@ -264,7 +264,7 @@ __device__ void abstract_threshold_filter( auto begin = row_ptrs[row]; auto end = row_ptrs[row + 1]; begin_cb(row); - auto diag_idx = is_lower ? end - 1 : begin; + auto diag_idx = end - 1; auto num_steps = ceildiv(end - begin, config::warp_size); for (auto step = 0; step < num_steps; ++step) { auto idx = begin + lane + step * config::warp_size; @@ -273,7 +273,8 @@ __device__ void abstract_threshold_filter( } auto keep = idx < end && (abs(val) >= threshold || idx == diag_idx); auto mask = warp.ballot(keep); - step_cb(idx, keep, popcnt(mask), popcnt(mask & lane_prefix_mask)); + step_cb(row, idx, val, keep, popcnt(mask), + popcnt(mask & lane_prefix_mask)); } finish_cb(row, lane); } @@ -282,18 +283,18 @@ __device__ void abstract_threshold_filter( template __global__ __launch_bounds__(default_block_size) void threshold_filter_nnz( const IndexType *row_ptrs, const ValueType *vals, IndexType num_rows, - remove_complex threshold, IndexType *nnz, bool is_lower) + remove_complex threshold, IndexType *nnz) { IndexType count{}; - abstract_threshold_filter(row_ptrs, vals, num_rows, threshold, is_lower, - [](IndexType) {}, - [&](IndexType, bool, IndexType warp_count, - IndexType) { count += warp_count; }, - [&](IndexType row, IndexType lane) { - if (row < num_rows && lane == 0) { - nnz[row] = count; - } - }); + abstract_threshold_filter( + row_ptrs, vals, num_rows, threshold, [&](IndexType) { count = 0; }, + [&](IndexType, IndexType, ValueType, bool, IndexType warp_count, + IndexType) { count += warp_count; }, + [&](IndexType row, IndexType lane) { + if (row < num_rows && lane == 0) { + nnz[row] = count; + } + }); } @@ -302,20 +303,23 @@ __global__ __launch_bounds__(default_block_size) void threshold_filter( const IndexType *old_row_ptrs, const IndexType *old_col_idxs, const ValueType *old_vals, IndexType num_rows, remove_complex threshold, const IndexType *new_row_ptrs, - IndexType *new_col_idxs, ValueType *new_vals, bool is_lower) + IndexType *new_row_idxs, IndexType *new_col_idxs, ValueType *new_vals) { IndexType count{}; IndexType new_offset{}; abstract_threshold_filter( - old_row_ptrs, old_vals, num_rows, threshold, is_lower, - [&](IndexType row) { new_offset = new_row_ptrs[row]; }, - [&](IndexType idx, bool keep, IndexType warp_count, - IndexType warp_prefix_sum) { + old_row_ptrs, old_vals, num_rows, threshold, + [&](IndexType row) { + new_offset = new_row_ptrs[row]; + count = 0; + }, + [&](IndexType row, IndexType idx, ValueType val, bool keep, + IndexType warp_count, IndexType warp_prefix_sum) { if (keep) { auto new_idx = new_offset + warp_prefix_sum + count; + new_row_idxs[new_idx] = row; new_col_idxs[new_idx] = old_col_idxs[idx]; - // hopefully the compiler is able to remove this duplicate load - new_vals[new_idx] = old_vals[idx]; + new_vals[new_idx] = val; } count += warp_count; }, @@ -452,4 +456,64 @@ __global__ __launch_bounds__(default_block_size) void tri_spgeam_init( } +template +__global__ __launch_bounds__(default_block_size) void sweep( + const IndexType *a_row_ptrs, const IndexType *a_col_idxs, + const ValueType *a_vals, const IndexType *l_row_ptrs, + const IndexType *l_row_idxs, const IndexType *l_col_idxs, ValueType *l_vals, + IndexType l_nnz, const IndexType *u_col_ptrs, const IndexType *u_row_idxs, + const IndexType *u_col_idxs, ValueType *u_vals, IndexType u_nnz) +{ + auto tidx = (threadIdx.x + blockIdx.x * blockDim.x) / subwarp_size; + if (tidx >= l_nnz + u_nnz) { + return; + } + auto row = tidx < l_nnz ? l_row_idxs[tidx] : u_row_idxs[tidx - l_nnz]; + auto col = tidx < l_nnz ? l_col_idxs[tidx] : u_col_idxs[tidx - l_nnz]; + if (tidx < l_nnz && row == col) { + // don't update the diagonal twice + return; + } + auto subwarp = + group::tiled_partition(group::this_thread_block()); + auto a_row_begin = a_row_ptrs[row]; + auto a_row_size = a_row_ptrs[row + 1] - a_row_begin; + auto a_idx = + group_wide_search(a_row_begin, a_row_size, subwarp, + [&](IndexType i) { return a_col_idxs[i] >= col; }); + auto a_val = a_col_idxs[a_idx] == col ? a_vals[a_idx] : zero(); + auto l_row_begin = l_row_ptrs[row]; + auto l_row_size = l_row_ptrs[row + 1] - l_row_begin; + auto u_col_begin = u_col_ptrs[row]; + auto u_col_size = u_col_ptrs[row + 1] - u_col_begin; + ValueType sum{}; + ValueType last_product{}; + IndexType l_out_idx{}; + IndexType u_out_idx{}; + group_merge(l_col_idxs, l_row_begin, l_row_size, u_row_idxs, u_col_begin, + u_col_size, + [&](IndexType l_idx, ValueType l_col, IndexType u_idx, + ValueType u_row, IndexType) { + if (l_col == u_row) { + l_out_idx = l_idx; + u_out_idx = u_idx; + last_product = l_vals[l_idx] * u_vals[u_idx]; + sum += last_product; + } + }); + + if (row > col) { + auto to_write = sum / u_vals[u_col_ptrs[col + 1] - 1]; + if (::gko::isfinite(to_write)) { + l_vals[l_out_idx] = to_write; + } + } else { + auto to_write = sum; + if (::gko::isfinite(to_write)) { + u_vals[u_out_idx] = to_write; + } + } +} + + } // namespace kernel \ No newline at end of file diff --git a/core/factorization/par_ilut.cpp b/core/factorization/par_ilut.cpp index 4b150f15fc9..a3648d23a4d 100644 --- a/core/factorization/par_ilut.cpp +++ b/core/factorization/par_ilut.cpp @@ -122,6 +122,14 @@ struct ParIlutState { std::unique_ptr l_coo; // transposed upper factor U currently being updated std::unique_ptr u_transp_coo; + // temporary array for threshold selection + Array selection_tmp; + // temporary array for threshold selection + Array> selection_tmp2; + // strategy to be used by the lower factor + std::shared_ptr l_strategy; + // strategy to be used by the upper factor + std::shared_ptr u_strategy; ParIlutState(std::shared_ptr exec_in, const CsrMatrix *system_matrix_in, @@ -131,8 +139,8 @@ struct ParIlutState { system_matrix{system_matrix_in}, l{std::move(l_in)}, u{std::move(u_in)}, - l_row_idxs_array{exec}, - u_col_idxs_array{exec} + selection_tmp{exec}, + selection_tmp2{exec} { auto mtx_size = system_matrix->get_size(); auto u_nnz = u->get_num_stored_elements(); @@ -178,7 +186,7 @@ ParIlut::generate_l_u( ->convert_to(csr_system_matrix_unique_ptr.get()); csr_system_matrix = csr_system_matrix_unique_ptr.get(); } - if (!skip_sorting) { + if (!parameters_.skip_sorting) { if (csr_system_matrix_unique_ptr == nullptr) { csr_system_matrix_unique_ptr = CsrMatrix::create(exec); csr_system_matrix_unique_ptr->copy_from(csr_system_matrix); @@ -203,12 +211,12 @@ ParIlut::generate_l_u( auto l_nnz = static_cast(l_nnz_it); auto u_nnz = static_cast(u_nnz_it); - auto l = - Csr::create(exec, Array{exec, l_nnz}, - Array{exec, l_nnz}, std::move(l_row_ptrs_array)); - auto u = - Csr::create(exec, Array{exec, u_nnz}, - Array{exec, u_nnz}, std::move(u_row_ptrs_array)); + auto l = CsrMatrix::create(exec, Array{exec, l_nnz}, + Array{exec, l_nnz}, + std::move(l_row_ptrs_array)); + auto u = CsrMatrix::create(exec, Array{exec, u_nnz}, + Array{exec, u_nnz}, + std::move(u_row_ptrs_array)); // initialize L and U exec->run(make_initialize_l_u(csr_system_matrix, l.get(), u.get())); @@ -276,10 +284,12 @@ void ParIlutState::iterate() } else { // select threshold to remove smallest candidates remove_complex l_threshold{}; - exec->run(make_threshold_select(l_new.get(), l_nnz_limit, l_threshold)); + exec->run(make_threshold_select(l_new.get(), l_nnz_limit, l_threshold, + selection_tmp, selection_tmp2)); remove_complex u_threshold{}; - exec->run( - make_threshold_select(u_new_csc.get(), u_nnz_limit, u_threshold)); + exec->run(make_threshold_select(u_new_csc.get(), u_nnz_limit, + u_threshold, selection_tmp, + selection_tmp2)); // remove smallest candidates exec->run(make_threshold_filter(l_new.get(), l_threshold, l.get(), diff --git a/core/factorization/par_ilut_kernels.hpp b/core/factorization/par_ilut_kernels.hpp index 9f2b4276923..5275fde98a8 100644 --- a/core/factorization/par_ilut_kernels.hpp +++ b/core/factorization/par_ilut_kernels.hpp @@ -68,7 +68,8 @@ namespace kernels { #define GKO_DECLARE_PAR_ILUT_THRESHOLD_SELECT_KERNEL(ValueType, IndexType) \ void threshold_select(std::shared_ptr exec, \ const matrix::Csr *m, \ - IndexType rank, \ + IndexType rank, Array &tmp, \ + Array> &tmp2, \ remove_complex &threshold) #define GKO_DECLARE_PAR_ILUT_THRESHOLD_FILTER_KERNEL(ValueType, IndexType) \ void threshold_filter(std::shared_ptr exec, \ diff --git a/cuda/factorization/par_ilut_kernels.cu b/cuda/factorization/par_ilut_kernels.cu index fd6bb370d7c..702518f4397 100644 --- a/cuda/factorization/par_ilut_kernels.cu +++ b/cuda/factorization/par_ilut_kernels.cu @@ -43,6 +43,9 @@ OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. #include +#include "core/components/prefix_sum.hpp" +#include "core/matrix/coo_builder.hpp" +#include "core/matrix/csr_builder.hpp" #include "cuda/base/cusparse_bindings.hpp" #include "cuda/base/math.hpp" #include "cuda/base/pointer_mode_guard.hpp" @@ -108,28 +111,48 @@ void ssss_filter(const ValueType *values, IndexType size, template -remove_complex threshold_select( - std::shared_ptr exec, const ValueType *values, - IndexType size, IndexType rank) +void threshold_select(std::shared_ptr exec, + const matrix::Csr *m, + IndexType rank, Array &tmp1, + Array> &tmp2, + remove_complex &threshold) { + auto values = m->get_const_values(); + IndexType size = m->get_num_stored_elements(); using AbsType = remove_complex; constexpr auto bucket_count = kernel::searchtree_width; auto max_num_threads = ceildiv(size, items_per_thread); auto max_num_blocks = ceildiv(max_num_threads, default_block_size); - // we use the last entry to store the total element count - Array total_counts_array(exec, bucket_count + 1); - Array partial_counts_array(exec, bucket_count * max_num_blocks); - Array oracle_array(exec, size); - Array tree_array(exec, kernel::searchtree_size); - auto partial_counts = partial_counts_array.get_data(); - auto total_counts = total_counts_array.get_data(); - auto oracles = oracle_array.get_data(); - auto tree = tree_array.get_data(); + size_type tmp_size_totals = + ceildiv((bucket_count + 1) * sizeof(IndexType), sizeof(ValueType)); + size_type tmp_size_partials = ceildiv( + bucket_count * max_num_blocks * sizeof(IndexType), sizeof(ValueType)); + size_type tmp_size_oracles = + ceildiv(size * sizeof(unsigned char), sizeof(ValueType)); + size_type tmp_size_tree = + ceildiv(kernel::searchtree_size * sizeof(AbsType), sizeof(ValueType)); + size_type tmp_size_vals = + size / bucket_count * 4; // pessimistic estimate for temporary storage + size_type tmp_size = + tmp_size_totals + tmp_size_partials + tmp_size_oracles + tmp_size_tree; + tmp1.resize_and_reset(tmp_size); + tmp2.resize_and_reset(tmp_size_vals); + + auto total_counts = reinterpret_cast(tmp1.get_data()); + auto partial_counts = + reinterpret_cast(tmp1.get_data() + tmp_size_totals); + auto oracles = reinterpret_cast( + tmp1.get_data() + tmp_size_totals + tmp_size_partials); + auto tree = + reinterpret_cast(tmp1.get_data() + tmp_size_totals + + tmp_size_partials + tmp_size_oracles); ssss_count(values, size, tree, oracles, partial_counts, total_counts); // determine bucket with correct rank + auto total_counts_array = + Array::view(exec, bucket_count + 1, total_counts); Array splitter_ranks_array(exec->get_master(), total_counts_array); auto splitter_ranks = splitter_ranks_array.get_const_data(); @@ -139,19 +162,21 @@ remove_complex threshold_select( auto bucket_size = splitter_ranks[bucket + 1] - splitter_ranks[bucket]; rank -= splitter_ranks[bucket]; - Array tmp_out_array(exec, bucket_size); - Array tmp_in_array(exec, bucket_size); - auto tmp_out = tmp_out_array.get_data(); - auto tmp_in = tmp_in_array.get_const_data(); + if (bucket_size * 2 > tmp_size_vals) { + // we need to reallocate tmp2 + tmp2.resize_and_reset(bucket_size * 2); + } + auto tmp21 = tmp2.get_data(); + auto tmp22 = tmp2.get_data() + bucket_size; // extract target bucket - ssss_filter(values, size, oracles, partial_counts, bucket, tmp_out); + ssss_filter(values, size, oracles, partial_counts, bucket, tmp22); // recursively select from smaller buckets int step{}; while (bucket_size > kernel::basecase_size) { - std::swap(tmp_out_array, tmp_in_array); - tmp_out = tmp_out_array.get_data(); - tmp_in = tmp_in_array.get_const_data(); + std::swap(tmp21, tmp22); + const auto *tmp_in = tmp21; + auto tmp_out = tmp22; ssss_count(tmp_in, bucket_size, tree, oracles, partial_counts, total_counts); @@ -169,23 +194,23 @@ remove_complex threshold_select( // 256^5 = 2^40. fall back to standard library algorithm in that case. ++step; if (step > 5) { - Array cpu_out_array{exec->get_master(), tmp_out_array}; + Array cpu_out_array{ + exec->get_master(), + Array::view(exec, bucket_size, tmp_out)}; auto begin = cpu_out_array.get_data(); auto end = begin + bucket_size; auto middle = begin + rank; std::nth_element(begin, middle, end); - return *middle; + threshold = *middle; + return; } } // base case - Array result_array{exec, 1}; + auto out_ptr = reinterpret_cast(tmp1.get_data()); kernel::basecase_select<<<1, kernel::basecase_block_size>>>( - tmp_out, bucket_size, rank, result_array.get_data()); - AbsType result{}; - exec->get_master()->copy_from(exec.get(), 1, result_array.get_const_data(), - &result); - return result; + tmp22, bucket_size, rank, out_ptr); + exec->get_master()->copy_from(exec.get(), 1, out_ptr, &threshold); } @@ -197,9 +222,8 @@ template void threshold_filter(std::shared_ptr exec, const matrix::Csr *a, remove_complex threshold, - Array &new_row_ptrs_array, - Array &new_col_idxs_array, - Array &new_vals_array, bool is_lower) + matrix::Csr *m_out, + matrix::Coo *m_out_coo) { auto old_row_ptrs = a->get_const_row_ptrs(); auto old_col_idxs = a->get_const_col_idxs(); @@ -208,37 +232,34 @@ void threshold_filter(std::shared_ptr exec, auto num_rows = IndexType(a->get_size()[0]); auto block_size = default_block_size / config::warp_size; auto num_blocks = ceildiv(num_rows, block_size); - new_row_ptrs_array.resize_and_reset(num_rows + 1); - auto new_row_ptrs = new_row_ptrs_array.get_data(); - auto block_dim = dim3(config::warp_size, num_blocks); - kernel::threshold_filter_nnz<<>>( - old_row_ptrs, as_cuda_type(old_vals), num_rows, threshold, new_row_ptrs, - is_lower); + auto new_row_ptrs = m_out->get_row_ptrs(); + kernel::threshold_filter_nnz<<>>( + old_row_ptrs, as_cuda_type(old_vals), num_rows, threshold, + new_row_ptrs); // build row pointers - auto num_row_ptrs = num_rows + 1; - auto num_reduce_blocks = ceildiv(num_row_ptrs, default_block_size); - Array block_counts_array(exec, num_reduce_blocks); - auto block_counts = block_counts_array.get_data(); - - start_prefix_sum - <<>>(num_row_ptrs, new_row_ptrs, - block_counts); - finalize_prefix_sum - <<>>(num_row_ptrs, new_row_ptrs, - block_counts); + prefix_sum(exec, new_row_ptrs, num_rows + 1); // build matrix - IndexType num_nnz{}; + IndexType new_nnz{}; exec->get_master()->copy_from(exec.get(), 1, new_row_ptrs + num_rows, - &num_nnz); - new_col_idxs_array.resize_and_reset(num_nnz); - new_vals_array.resize_and_reset(num_nnz); - auto new_col_idxs = new_col_idxs_array.get_data(); - auto new_vals = new_vals_array.get_data(); - kernel::threshold_filter<<>>( + &new_nnz); + // resize arrays and update aliases + matrix::CsrBuilder builder{m_out}; + builder.get_col_idx_array().resize_and_reset(new_nnz); + builder.get_value_array().resize_and_reset(new_nnz); + auto new_col_idxs = m_out->get_col_idxs(); + auto new_vals = m_out->get_values(); + matrix::CooBuilder coo_builder{m_out_coo}; + coo_builder.get_row_idx_array().resize_and_reset(new_nnz); + coo_builder.get_col_idx_array() = + Array::view(exec, new_nnz, new_col_idxs); + coo_builder.get_value_array() = + Array::view(exec, new_nnz, new_vals); + auto new_row_idxs = m_out_coo->get_row_idxs(); + kernel::threshold_filter<<>>( old_row_ptrs, old_col_idxs, as_cuda_type(old_vals), num_rows, threshold, - new_row_ptrs, new_col_idxs, as_cuda_type(new_vals), is_lower); + new_row_ptrs, new_row_idxs, new_col_idxs, as_cuda_type(new_vals)); } diff --git a/cuda/test/factorization/par_ilut_kernels.cpp b/cuda/test/factorization/par_ilut_kernels.cpp index 96a7c3053f9..926f88102b8 100644 --- a/cuda/test/factorization/par_ilut_kernels.cpp +++ b/cuda/test/factorization/par_ilut_kernels.cpp @@ -130,45 +130,49 @@ class ParIlut : public ::testing::Test { template void test_select(const std::unique_ptr &mtx, - const std::unique_ptr &dmtx, index_type rank) + const std::unique_ptr &dmtx, index_type rank, + value_type tolerance = 0.0) { auto size = index_type(mtx->get_num_stored_elements()); + using ValueType = typename Mtx::value_type; - auto res = - gko::kernels::reference::par_ilut_factorization::threshold_select( - ref, mtx->get_const_values(), size, rank); - auto dres = - gko::kernels::cuda::par_ilut_factorization::threshold_select( - cuda, dmtx->get_const_values(), size, rank); + gko::remove_complex res{}; + gko::remove_complex dres{}; + gko::Array tmp(ref); + gko::Array> tmp2(ref); + gko::Array dtmp(cuda); + gko::Array> dtmp2(cuda); - ASSERT_EQ(res, dres); + gko::kernels::reference::par_ilut_factorization::threshold_select( + ref, mtx.get(), rank, tmp, tmp2, res); + gko::kernels::cuda::par_ilut_factorization::threshold_select( + cuda, dmtx.get(), rank, dtmp, dtmp2, dres); + + ASSERT_NEAR(res, dres, tolerance); } - template + template > void test_filter(const std::unique_ptr &mtx, - const std::unique_ptr &dmtx, value_type threshold, - bool lower) + const std::unique_ptr &dmtx, value_type threshold) { - gko::Array new_row_ptrs(ref); - gko::Array new_col_idxs(ref); - gko::Array new_vals(ref); - gko::Array dnew_row_ptrs(cuda); - gko::Array dnew_col_idxs(cuda); - gko::Array dnew_vals(cuda); + auto res = Mtx::create(ref, mtx_size); + auto dres = Mtx::create(cuda, mtx_size); + auto res_coo = Coo::create(ref, mtx_size); + auto dres_coo = Coo::create(cuda, mtx_size); gko::kernels::reference::par_ilut_factorization::threshold_filter( - ref, mtx.get(), threshold, new_row_ptrs, new_col_idxs, new_vals, - lower); + ref, mtx.get(), threshold, res.get(), res_coo.get()); gko::kernels::cuda::par_ilut_factorization::threshold_filter( - cuda, dmtx.get(), threshold, dnew_row_ptrs, dnew_col_idxs, - dnew_vals, lower); - auto res = - Mtx::create(ref, mtx_size, new_vals, new_col_idxs, new_row_ptrs); - auto dres = Mtx::create(cuda, mtx_size, dnew_vals, dnew_col_idxs, - dnew_row_ptrs); + cuda, dmtx.get(), threshold, dres.get(), dres_coo.get()); GKO_ASSERT_MTX_NEAR(res, dres, 0); GKO_ASSERT_MTX_EQ_SPARSITY(res, dres); + GKO_ASSERT_MTX_NEAR(res, res_coo, 0); + GKO_ASSERT_MTX_EQ_SPARSITY(res, res_coo); + GKO_ASSERT_MTX_NEAR(dres, dres_coo, 0); + GKO_ASSERT_MTX_EQ_SPARSITY(dres, dres_coo); } std::shared_ptr ref; @@ -222,92 +226,56 @@ TEST_F(ParIlut, KernelThresholdSelectMaxIsEquivalentToRef) TEST_F(ParIlut, KernelComplexThresholdSelectIsEquivalentToRef) { test_select(mtx_l_complex, dmtx_l_complex, - mtx_l->get_num_stored_elements() / 3); + mtx_l_complex->get_num_stored_elements() / 3, 1e-14); } TEST_F(ParIlut, KernelComplexThresholdSelectMinIsEquivalentToRef) { - test_select(mtx_l_complex, dmtx_l_complex, 0); + test_select(mtx_l_complex, dmtx_l_complex, 0, 1e-14); } TEST_F(ParIlut, KernelComplexThresholdSelectMaxLowerIsEquivalentToRef) { test_select(mtx_l_complex, dmtx_l_complex, - mtx_l->get_num_stored_elements() - 1); + mtx_l_complex->get_num_stored_elements() - 1, 1e-14); } TEST_F(ParIlut, KernelThresholdFilterLowerIsEquivalentToRef) { - test_filter(mtx_l, dmtx_l, 0.5, true); -} - - -TEST_F(ParIlut, KernelThresholdFilterUpperIsEquivalentToRef) -{ - test_filter(mtx_u, dmtx_u, 0.5, false); + test_filter(mtx_l, dmtx_l, 0.5); } TEST_F(ParIlut, KernelThresholdFilterNoneLowerIsEquivalentToRef) { - test_filter(mtx_l, dmtx_l, 0, true); -} - - -TEST_F(ParIlut, KernelThresholdFilterNoneUpperIsEquivalentToRef) -{ - test_filter(mtx_u, dmtx_u, 0, false); + test_filter(mtx_l, dmtx_l, 0); } TEST_F(ParIlut, KernelThresholdFilterAllLowerIsEquivalentToRef) { - test_filter(mtx_l, dmtx_l, 1e6, true); -} - - -TEST_F(ParIlut, KernelThresholdFilterAllUpperIsEquivalentToRef) -{ - test_filter(mtx_u, dmtx_u, 1e6, false); + test_filter(mtx_l, dmtx_l, 1e6); } TEST_F(ParIlut, KernelComplexThresholdFilterLowerIsEquivalentToRef) { - test_filter(mtx_l_complex, dmtx_l_complex, 0.5, true); -} - - -TEST_F(ParIlut, KernelComplexThresholdFilterUpperIsEquivalentToRef) -{ - test_filter(mtx_u_complex, dmtx_u_complex, 0.5, false); + test_filter(mtx_l_complex, dmtx_l_complex, 0.5); } TEST_F(ParIlut, KernelComplexThresholdFilterNoneLowerIsEquivalentToRef) { - test_filter(mtx_l_complex, dmtx_l_complex, 0, true); -} - - -TEST_F(ParIlut, KernelComplexThresholdFilterNoneUpperIsEquivalentToRef) -{ - test_filter(mtx_u_complex, dmtx_u_complex, 0, false); + test_filter(mtx_l_complex, dmtx_l_complex, 0); } TEST_F(ParIlut, KernelComplexThresholdFilterAllLowerIsEquivalentToRef) { - test_filter(mtx_l_complex, dmtx_l_complex, 1e6, true); -} - - -TEST_F(ParIlut, KernelComplexThresholdFilterAllUpperIsEquivalentToRef) -{ - test_filter(mtx_u_complex, dmtx_u_complex, 1e6, false); + test_filter(mtx_l_complex, dmtx_l_complex, 1e6); } diff --git a/hip/factorization/par_ilut_kernels.hip.cpp b/hip/factorization/par_ilut_kernels.hip.cpp index da9a60f0fdb..7dbcb7d7492 100644 --- a/hip/factorization/par_ilut_kernels.hip.cpp +++ b/hip/factorization/par_ilut_kernels.hip.cpp @@ -36,6 +36,9 @@ OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. #include +#include + + #include #include #include @@ -43,6 +46,9 @@ OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. #include +#include "core/components/prefix_sum.hpp" +#include "core/matrix/coo_builder.hpp" +#include "core/matrix/csr_builder.hpp" #include "core/matrix/csr_kernels.hpp" #include "hip/base/math.hip.hpp" #include "hip/components/atomic.hip.hpp" @@ -111,28 +117,48 @@ void ssss_filter(const ValueType *values, IndexType size, template -remove_complex threshold_select( - std::shared_ptr exec, const ValueType *values, - IndexType size, IndexType rank) +void threshold_select(std::shared_ptr exec, + const matrix::Csr *m, + IndexType rank, Array &tmp1, + Array> &tmp2, + remove_complex &threshold) { + auto values = m->get_const_values(); + IndexType size = m->get_num_stored_elements(); using AbsType = remove_complex; constexpr auto bucket_count = kernel::searchtree_width; auto max_num_threads = ceildiv(size, items_per_thread); auto max_num_blocks = ceildiv(max_num_threads, default_block_size); - // we use the last entry to store the total element count - Array total_counts_array(exec, bucket_count + 1); - Array partial_counts_array(exec, bucket_count * max_num_blocks); - Array oracle_array(exec, size); - Array tree_array(exec, kernel::searchtree_size); - auto partial_counts = partial_counts_array.get_data(); - auto total_counts = total_counts_array.get_data(); - auto oracles = oracle_array.get_data(); - auto tree = tree_array.get_data(); + size_type tmp_size_totals = + ceildiv((bucket_count + 1) * sizeof(IndexType), sizeof(ValueType)); + size_type tmp_size_partials = ceildiv( + bucket_count * max_num_blocks * sizeof(IndexType), sizeof(ValueType)); + size_type tmp_size_oracles = + ceildiv(size * sizeof(unsigned char), sizeof(ValueType)); + size_type tmp_size_tree = + ceildiv(kernel::searchtree_size * sizeof(AbsType), sizeof(ValueType)); + size_type tmp_size_vals = + size / bucket_count * 4; // pessimistic estimate for temporary storage + size_type tmp_size = + tmp_size_totals + tmp_size_partials + tmp_size_oracles + tmp_size_tree; + tmp1.resize_and_reset(tmp_size); + tmp2.resize_and_reset(tmp_size_vals); + + auto total_counts = reinterpret_cast(tmp1.get_data()); + auto partial_counts = + reinterpret_cast(tmp1.get_data() + tmp_size_totals); + auto oracles = reinterpret_cast( + tmp1.get_data() + tmp_size_totals + tmp_size_partials); + auto tree = + reinterpret_cast(tmp1.get_data() + tmp_size_totals + + tmp_size_partials + tmp_size_oracles); ssss_count(values, size, tree, oracles, partial_counts, total_counts); // determine bucket with correct rank + auto total_counts_array = + Array::view(exec, bucket_count + 1, total_counts); Array splitter_ranks_array(exec->get_master(), total_counts_array); auto splitter_ranks = splitter_ranks_array.get_const_data(); @@ -142,19 +168,21 @@ remove_complex threshold_select( auto bucket_size = splitter_ranks[bucket + 1] - splitter_ranks[bucket]; rank -= splitter_ranks[bucket]; - Array tmp_out_array(exec, bucket_size); - Array tmp_in_array(exec, bucket_size); - auto tmp_out = tmp_out_array.get_data(); - auto tmp_in = tmp_in_array.get_const_data(); + if (bucket_size * 2 > tmp_size_vals) { + // we need to reallocate tmp2 + tmp2.resize_and_reset(bucket_size * 2); + } + auto tmp21 = tmp2.get_data(); + auto tmp22 = tmp2.get_data() + bucket_size; // extract target bucket - ssss_filter(values, size, oracles, partial_counts, bucket, tmp_out); + ssss_filter(values, size, oracles, partial_counts, bucket, tmp22); // recursively select from smaller buckets int step{}; while (bucket_size > kernel::basecase_size) { - std::swap(tmp_out_array, tmp_in_array); - tmp_out = tmp_out_array.get_data(); - tmp_in = tmp_in_array.get_const_data(); + std::swap(tmp21, tmp22); + const auto *tmp_in = tmp21; + auto tmp_out = tmp22; ssss_count(tmp_in, bucket_size, tree, oracles, partial_counts, total_counts); @@ -172,24 +200,24 @@ remove_complex threshold_select( // 256^5 = 2^40. fall back to standard library algorithm in that case. ++step; if (step > 5) { - Array cpu_out_array{exec->get_master(), tmp_out_array}; + Array cpu_out_array{ + exec->get_master(), + Array::view(exec, bucket_size, tmp_out)}; auto begin = cpu_out_array.get_data(); auto end = begin + bucket_size; auto middle = begin + rank; std::nth_element(begin, middle, end); - return *middle; + threshold = *middle; + return; } } // base case - Array result_array{exec, 1}; + auto out_ptr = reinterpret_cast(tmp1.get_data()); hipLaunchKernelGGL(HIP_KERNEL_NAME(kernel::basecase_select), dim3(1), - dim3(kernel::basecase_block_size), 0, 0, tmp_out, - bucket_size, rank, result_array.get_data()); - AbsType result{}; - exec->get_master()->copy_from(exec.get(), 1, result_array.get_const_data(), - &result); - return result; + dim3(kernel::basecase_block_size), 0, 0, tmp22, + bucket_size, rank, out_ptr); + exec->get_master()->copy_from(exec.get(), 1, out_ptr, &threshold); } @@ -201,49 +229,47 @@ template void threshold_filter(std::shared_ptr exec, const matrix::Csr *a, remove_complex threshold, - Array &new_row_ptrs_array, - Array &new_col_idxs_array, - Array &new_vals_array, bool is_lower) + matrix::Csr *m_out, + matrix::Coo *m_out_coo) { auto old_row_ptrs = a->get_const_row_ptrs(); auto old_col_idxs = a->get_const_col_idxs(); auto old_vals = a->get_const_values(); // compute nnz for each row auto num_rows = IndexType(a->get_size()[0]); - auto num_blocks = ceildiv(num_rows, default_block_size / config::warp_size); - new_row_ptrs_array.resize_and_reset(num_rows + 1); - auto new_row_ptrs = new_row_ptrs_array.get_data(); + auto block_size = default_block_size / config::warp_size; + auto num_blocks = ceildiv(num_rows, block_size); + auto new_row_ptrs = m_out->get_row_ptrs(); hipLaunchKernelGGL(HIP_KERNEL_NAME(kernel::threshold_filter_nnz), dim3(num_blocks), dim3(default_block_size), 0, 0, old_row_ptrs, as_hip_type(old_vals), num_rows, threshold, - new_row_ptrs, is_lower); + new_row_ptrs); // build row pointers - auto num_row_ptrs = num_rows + 1; - auto num_reduce_blocks = ceildiv(num_row_ptrs, default_block_size); - Array block_counts_array(exec, num_reduce_blocks); - auto block_counts = block_counts_array.get_data(); - - hipLaunchKernelGGL(HIP_KERNEL_NAME(start_prefix_sum), - dim3(num_reduce_blocks), dim3(default_block_size), 0, 0, - num_row_ptrs, new_row_ptrs, block_counts); - hipLaunchKernelGGL(HIP_KERNEL_NAME(finalize_prefix_sum), - dim3(num_reduce_blocks), dim3(default_block_size), 0, 0, - num_row_ptrs, new_row_ptrs, block_counts); + prefix_sum(exec, new_row_ptrs, num_rows + 1); // build matrix - IndexType num_nnz{}; + IndexType new_nnz{}; exec->get_master()->copy_from(exec.get(), 1, new_row_ptrs + num_rows, - &num_nnz); - new_col_idxs_array.resize_and_reset(num_nnz); - new_vals_array.resize_and_reset(num_nnz); - auto new_col_idxs = new_col_idxs_array.get_data(); - auto new_vals = new_vals_array.get_data(); + &new_nnz); + // resize arrays and update aliases + matrix::CsrBuilder builder{m_out}; + builder.get_col_idx_array().resize_and_reset(new_nnz); + builder.get_value_array().resize_and_reset(new_nnz); + auto new_col_idxs = m_out->get_col_idxs(); + auto new_vals = m_out->get_values(); + matrix::CooBuilder coo_builder{m_out_coo}; + coo_builder.get_row_idx_array().resize_and_reset(new_nnz); + coo_builder.get_col_idx_array() = + Array::view(exec, new_nnz, new_col_idxs); + coo_builder.get_value_array() = + Array::view(exec, new_nnz, new_vals); + auto new_row_idxs = m_out_coo->get_row_idxs(); hipLaunchKernelGGL(HIP_KERNEL_NAME(kernel::threshold_filter), dim3(num_blocks), dim3(default_block_size), 0, 0, old_row_ptrs, old_col_idxs, as_hip_type(old_vals), - num_rows, threshold, new_row_ptrs, new_col_idxs, - as_hip_type(new_vals), is_lower); + num_rows, threshold, new_row_ptrs, new_row_idxs, + new_col_idxs, as_hip_type(new_vals)); } diff --git a/hip/test/factorization/par_ilut_kernels.hip.cpp b/hip/test/factorization/par_ilut_kernels.hip.cpp index 99f2396828e..d96318b2bbe 100644 --- a/hip/test/factorization/par_ilut_kernels.hip.cpp +++ b/hip/test/factorization/par_ilut_kernels.hip.cpp @@ -130,48 +130,49 @@ class ParIlut : public ::testing::Test { template void test_select(const std::unique_ptr &mtx, - const std::unique_ptr &dmtx, index_type rank) + const std::unique_ptr &dmtx, index_type rank, + value_type tolerance = 0.0) { auto size = index_type(mtx->get_num_stored_elements()); + using ValueType = typename Mtx::value_type; - auto res = - gko::kernels::reference::par_ilut_factorization::threshold_select( - ref, mtx->get_const_values(), size, rank); - auto dres = gko::kernels::hip::par_ilut_factorization::threshold_select( - hip, dmtx->get_const_values(), size, rank); - - if (gko::is_complex_s::value) { - ASSERT_NEAR(res, dres, 1e-14); - } else { - ASSERT_EQ(res, dres); - } + gko::remove_complex res{}; + gko::remove_complex dres{}; + gko::Array tmp(ref); + gko::Array> tmp2(ref); + gko::Array dtmp(hip); + gko::Array> dtmp2(hip); + + gko::kernels::reference::par_ilut_factorization::threshold_select( + ref, mtx.get(), rank, tmp, tmp2, res); + gko::kernels::hip::par_ilut_factorization::threshold_select( + hip, dmtx.get(), rank, dtmp, dtmp2, dres); + + ASSERT_NEAR(res, dres, tolerance); } - template + template > void test_filter(const std::unique_ptr &mtx, - const std::unique_ptr &dmtx, value_type threshold, - bool lower) + const std::unique_ptr &dmtx, value_type threshold) { - gko::Array new_row_ptrs(ref); - gko::Array new_col_idxs(ref); - gko::Array new_vals(ref); - gko::Array dnew_row_ptrs(hip); - gko::Array dnew_col_idxs(hip); - gko::Array dnew_vals(hip); + auto res = Mtx::create(ref, mtx_size); + auto dres = Mtx::create(hip, mtx_size); + auto res_coo = Coo::create(ref, mtx_size); + auto dres_coo = Coo::create(hip, mtx_size); gko::kernels::reference::par_ilut_factorization::threshold_filter( - ref, mtx.get(), threshold, new_row_ptrs, new_col_idxs, new_vals, - lower); + ref, mtx.get(), threshold, res.get(), res_coo.get()); gko::kernels::hip::par_ilut_factorization::threshold_filter( - hip, dmtx.get(), threshold, dnew_row_ptrs, dnew_col_idxs, dnew_vals, - lower); - auto res = - Mtx::create(ref, mtx_size, new_vals, new_col_idxs, new_row_ptrs); - auto dres = - Mtx::create(hip, mtx_size, dnew_vals, dnew_col_idxs, dnew_row_ptrs); + hip, dmtx.get(), threshold, dres.get(), dres_coo.get()); GKO_ASSERT_MTX_NEAR(res, dres, 0); GKO_ASSERT_MTX_EQ_SPARSITY(res, dres); + GKO_ASSERT_MTX_NEAR(res, res_coo, 0); + GKO_ASSERT_MTX_EQ_SPARSITY(res, res_coo); + GKO_ASSERT_MTX_NEAR(dres, dres_coo, 0); + GKO_ASSERT_MTX_EQ_SPARSITY(dres, dres_coo); } std::shared_ptr ref; @@ -225,92 +226,56 @@ TEST_F(ParIlut, KernelThresholdSelectMaxIsEquivalentToRef) TEST_F(ParIlut, KernelComplexThresholdSelectIsEquivalentToRef) { test_select(mtx_l_complex, dmtx_l_complex, - mtx_l_complex->get_num_stored_elements() / 3); + mtx_l_complex->get_num_stored_elements() / 3, 1e-14); } TEST_F(ParIlut, KernelComplexThresholdSelectMinIsEquivalentToRef) { - test_select(mtx_l_complex, dmtx_l_complex, 0); + test_select(mtx_l_complex, dmtx_l_complex, 0, 1e-14); } TEST_F(ParIlut, KernelComplexThresholdSelectMaxLowerIsEquivalentToRef) { test_select(mtx_l_complex, dmtx_l_complex, - mtx_l_complex->get_num_stored_elements() - 1); + mtx_l_complex->get_num_stored_elements() - 1, 1e-14); } TEST_F(ParIlut, KernelThresholdFilterLowerIsEquivalentToRef) { - test_filter(mtx_l, dmtx_l, 0.5, true); -} - - -TEST_F(ParIlut, KernelThresholdFilterUpperIsEquivalentToRef) -{ - test_filter(mtx_u, dmtx_u, 0.5, false); + test_filter(mtx_l, dmtx_l, 0.5); } TEST_F(ParIlut, KernelThresholdFilterNoneLowerIsEquivalentToRef) { - test_filter(mtx_l, dmtx_l, 0, true); -} - - -TEST_F(ParIlut, KernelThresholdFilterNoneUpperIsEquivalentToRef) -{ - test_filter(mtx_u, dmtx_u, 0, false); + test_filter(mtx_l, dmtx_l, 0); } TEST_F(ParIlut, KernelThresholdFilterAllLowerIsEquivalentToRef) { - test_filter(mtx_l, dmtx_l, 1e6, true); -} - - -TEST_F(ParIlut, KernelThresholdFilterAllUpperIsEquivalentToRef) -{ - test_filter(mtx_u, dmtx_u, 1e6, false); + test_filter(mtx_l, dmtx_l, 1e6); } TEST_F(ParIlut, KernelComplexThresholdFilterLowerIsEquivalentToRef) { - test_filter(mtx_l_complex, dmtx_l_complex, 0.5, true); -} - - -TEST_F(ParIlut, KernelComplexThresholdFilterUpperIsEquivalentToRef) -{ - test_filter(mtx_u_complex, dmtx_u_complex, 0.5, false); + test_filter(mtx_l_complex, dmtx_l_complex, 0.5); } TEST_F(ParIlut, KernelComplexThresholdFilterNoneLowerIsEquivalentToRef) { - test_filter(mtx_l_complex, dmtx_l_complex, 0, true); -} - - -TEST_F(ParIlut, KernelComplexThresholdFilterNoneUpperIsEquivalentToRef) -{ - test_filter(mtx_u_complex, dmtx_u_complex, 0, false); + test_filter(mtx_l_complex, dmtx_l_complex, 0); } TEST_F(ParIlut, KernelComplexThresholdFilterAllLowerIsEquivalentToRef) { - test_filter(mtx_l_complex, dmtx_l_complex, 1e6, true); -} - - -TEST_F(ParIlut, KernelComplexThresholdFilterAllUpperIsEquivalentToRef) -{ - test_filter(mtx_u_complex, dmtx_u_complex, 1e6, false); + test_filter(mtx_l_complex, dmtx_l_complex, 1e6); } diff --git a/omp/factorization/par_ilut_kernels.cpp b/omp/factorization/par_ilut_kernels.cpp index bb3480bae8e..665f9623d93 100644 --- a/omp/factorization/par_ilut_kernels.cpp +++ b/omp/factorization/par_ilut_kernels.cpp @@ -43,6 +43,11 @@ OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. #include +#include "core/components/prefix_sum.hpp" +#include "core/matrix/coo_builder.hpp" +#include "core/matrix/csr_builder.hpp" + + namespace gko { namespace kernels { namespace omp { @@ -55,19 +60,23 @@ namespace par_ilut_factorization { template -remove_complex threshold_select( - std::shared_ptr exec, const ValueType *values, - IndexType size, IndexType rank) +void threshold_select(std::shared_ptr exec, + const matrix::Csr *m, + IndexType rank, Array &tmp, + Array> &, + remove_complex &threshold) { - Array data(exec, size); - std::copy_n(values, size, data.get_data()); + auto values = m->get_const_values(); + IndexType size = m->get_num_stored_elements(); + tmp.resize_and_reset(size); + std::copy_n(values, size, tmp.get_data()); - auto begin = data.get_data(); + auto begin = tmp.get_data(); auto target = begin + rank; auto end = begin + size; std::nth_element(begin, target, end, [](ValueType a, ValueType b) { return abs(a) < abs(b); }); - return abs(*target); + threshold = abs(*target); } @@ -79,9 +88,8 @@ template void threshold_filter(std::shared_ptr exec, const matrix::Csr *a, remove_complex threshold, - Array &new_row_ptrs_array, - Array &new_col_idxs_array, - Array &new_vals_array, bool is_lower) + matrix::Csr *m_out, + matrix::Coo *m_out_coo) { auto num_rows = a->get_size()[0]; auto row_ptrs = a->get_const_row_ptrs(); @@ -89,57 +97,56 @@ void threshold_filter(std::shared_ptr exec, auto vals = a->get_const_values(); // first sweep: count nnz for each row - new_row_ptrs_array.resize_and_reset(num_rows + 1); - auto new_row_ptrs = new_row_ptrs_array.get_data(); + auto new_row_ptrs = m_out->get_row_ptrs(); #pragma omp parallel for for (size_type row = 0; row < num_rows; ++row) { // ignoring diagonal entries: - // lower triangular part has the diagonal last - // upper triangular part has the diagonal first - size_type begin = row_ptrs[row] + !is_lower; - size_type end = row_ptrs[row + 1] - is_lower; - new_row_ptrs[row + 1] = + // lower triangular matrix has the diagonal last + size_type begin = row_ptrs[row]; + size_type end = row_ptrs[row + 1] - 1; + new_row_ptrs[row] = std::count_if(vals + begin, vals + end, [&](ValueType v) { return abs(v) >= threshold; }); // add diagonal - new_row_ptrs[row + 1]++; + new_row_ptrs[row]++; } // build row pointers: exclusive scan (thus the + 1) - new_row_ptrs[0] = 0; - std::partial_sum(new_row_ptrs + 1, new_row_ptrs + num_rows + 1, - new_row_ptrs + 1); + prefix_sum(exec, new_row_ptrs, num_rows + 1); // second sweep: accumulate non-zeros auto new_nnz = new_row_ptrs[num_rows]; - new_col_idxs_array.resize_and_reset(new_nnz); - new_vals_array.resize_and_reset(new_nnz); - auto new_col_idxs = new_col_idxs_array.get_data(); - auto new_vals = new_vals_array.get_data(); + // resize arrays and update aliases + matrix::CsrBuilder builder{m_out}; + builder.get_col_idx_array().resize_and_reset(new_nnz); + builder.get_value_array().resize_and_reset(new_nnz); + auto new_col_idxs = m_out->get_col_idxs(); + auto new_vals = m_out->get_values(); + matrix::CooBuilder coo_builder{m_out_coo}; + coo_builder.get_row_idx_array().resize_and_reset(new_nnz); + coo_builder.get_col_idx_array() = + Array::view(exec, new_nnz, new_col_idxs); + coo_builder.get_value_array() = + Array::view(exec, new_nnz, new_vals); + auto new_row_idxs = m_out_coo->get_row_idxs(); #pragma omp parallel for for (size_type row = 0; row < num_rows; ++row) { // ignoring diagonal entries: - // lower triangular part has the diagonal last - // upper triangular part has the diagonal first - size_type new_begin = new_row_ptrs[row] + !is_lower; - size_type new_end = new_row_ptrs[row + 1] - is_lower; - size_type begin = row_ptrs[row] + !is_lower; - size_type end = row_ptrs[row + 1] - is_lower; - size_type count{}; + // lower triangular matrix has the diagonal last + size_type out_nz = new_row_ptrs[row]; + size_type begin = row_ptrs[row]; + size_type end = row_ptrs[row + 1]; + size_type diag_pos = row_ptrs[row + 1] - 1; for (auto nz = begin; nz < end; ++nz) { - if (abs(vals[nz]) >= threshold) { - new_col_idxs[new_begin + count] = col_idxs[nz]; - new_vals[new_begin + count] = vals[nz]; - ++count; + if (abs(vals[nz]) >= threshold || nz == diag_pos) { + new_row_idxs[out_nz] = row; + new_col_idxs[out_nz] = col_idxs[nz]; + new_vals[out_nz] = vals[nz]; + ++out_nz; } } - // add diagonal - auto in_diag = is_lower ? end : begin - 1; - auto out_diag = is_lower ? new_end : new_begin - 1; - new_col_idxs[out_diag] = col_idxs[in_diag]; - new_vals[out_diag] = vals[in_diag]; } } diff --git a/omp/test/factorization/par_ilut_kernels.cpp b/omp/test/factorization/par_ilut_kernels.cpp index 56fd92b4b6b..1e42ea35e4b 100644 --- a/omp/test/factorization/par_ilut_kernels.cpp +++ b/omp/test/factorization/par_ilut_kernels.cpp @@ -130,44 +130,49 @@ class ParIlut : public ::testing::Test { template void test_select(const std::unique_ptr &mtx, - const std::unique_ptr &dmtx, index_type rank) + const std::unique_ptr &dmtx, index_type rank, + value_type tolerance = 0.0) { auto size = index_type(mtx->get_num_stored_elements()); + using ValueType = typename Mtx::value_type; - auto res = - gko::kernels::reference::par_ilut_factorization::threshold_select( - ref, mtx->get_const_values(), size, rank); - auto dres = gko::kernels::omp::par_ilut_factorization::threshold_select( - omp, dmtx->get_const_values(), size, rank); + gko::remove_complex res{}; + gko::remove_complex dres{}; + gko::Array tmp(ref); + gko::Array> tmp2(ref); + gko::Array dtmp(omp); + gko::Array> dtmp2(omp); + + gko::kernels::reference::par_ilut_factorization::threshold_select( + ref, mtx.get(), rank, tmp, tmp2, res); + gko::kernels::omp::par_ilut_factorization::threshold_select( + omp, dmtx.get(), rank, dtmp, dtmp2, dres); ASSERT_EQ(res, dres); } - template + template > void test_filter(const std::unique_ptr &mtx, - const std::unique_ptr &dmtx, value_type threshold, - bool lower) + const std::unique_ptr &dmtx, value_type threshold) { - gko::Array new_row_ptrs(ref); - gko::Array new_col_idxs(ref); - gko::Array new_vals(ref); - gko::Array dnew_row_ptrs(omp); - gko::Array dnew_col_idxs(omp); - gko::Array dnew_vals(omp); + auto res = Mtx::create(ref, mtx_size); + auto dres = Mtx::create(omp, mtx_size); + auto res_coo = Coo::create(ref, mtx_size); + auto dres_coo = Coo::create(omp, mtx_size); gko::kernels::reference::par_ilut_factorization::threshold_filter( - ref, mtx.get(), threshold, new_row_ptrs, new_col_idxs, new_vals, - lower); + ref, mtx.get(), threshold, res.get(), res_coo.get()); gko::kernels::omp::par_ilut_factorization::threshold_filter( - omp, dmtx.get(), threshold, dnew_row_ptrs, dnew_col_idxs, dnew_vals, - lower); - auto res = - Mtx::create(ref, mtx_size, new_vals, new_col_idxs, new_row_ptrs); - auto dres = - Mtx::create(omp, mtx_size, dnew_vals, dnew_col_idxs, dnew_row_ptrs); + omp, dmtx.get(), threshold, dres.get(), dres_coo.get()); GKO_ASSERT_MTX_NEAR(res, dres, 0); GKO_ASSERT_MTX_EQ_SPARSITY(res, dres); + GKO_ASSERT_MTX_NEAR(res, res_coo, 0); + GKO_ASSERT_MTX_EQ_SPARSITY(res, res_coo); + GKO_ASSERT_MTX_NEAR(dres, dres_coo, 0); + GKO_ASSERT_MTX_EQ_SPARSITY(dres, dres_coo); } std::shared_ptr ref; @@ -221,7 +226,7 @@ TEST_F(ParIlut, KernelThresholdSelectMaxIsEquivalentToRef) TEST_F(ParIlut, KernelComplexThresholdSelectIsEquivalentToRef) { test_select(mtx_l_complex, dmtx_l_complex, - mtx_l->get_num_stored_elements() / 3); + mtx_l_complex->get_num_stored_elements() / 3); } @@ -234,79 +239,43 @@ TEST_F(ParIlut, KernelComplexThresholdSelectMinIsEquivalentToRef) TEST_F(ParIlut, KernelComplexThresholdSelectMaxLowerIsEquivalentToRef) { test_select(mtx_l_complex, dmtx_l_complex, - mtx_l->get_num_stored_elements() - 1); + mtx_l_complex->get_num_stored_elements() - 1); } TEST_F(ParIlut, KernelThresholdFilterLowerIsEquivalentToRef) { - test_filter(mtx_l, dmtx_l, 0.5, true); -} - - -TEST_F(ParIlut, KernelThresholdFilterUpperIsEquivalentToRef) -{ - test_filter(mtx_u, dmtx_u, 0.5, false); + test_filter(mtx_l, dmtx_l, 0.5); } TEST_F(ParIlut, KernelThresholdFilterNoneLowerIsEquivalentToRef) { - test_filter(mtx_l, dmtx_l, 0, true); -} - - -TEST_F(ParIlut, KernelThresholdFilterNoneUpperIsEquivalentToRef) -{ - test_filter(mtx_u, dmtx_u, 0, false); + test_filter(mtx_l, dmtx_l, 0); } TEST_F(ParIlut, KernelThresholdFilterAllLowerIsEquivalentToRef) { - test_filter(mtx_l, dmtx_l, 1e6, true); -} - - -TEST_F(ParIlut, KernelThresholdFilterAllUpperIsEquivalentToRef) -{ - test_filter(mtx_u, dmtx_u, 1e6, false); + test_filter(mtx_l, dmtx_l, 1e6); } TEST_F(ParIlut, KernelComplexThresholdFilterLowerIsEquivalentToRef) { - test_filter(mtx_l_complex, dmtx_l_complex, 0.5, true); -} - - -TEST_F(ParIlut, KernelComplexThresholdFilterUpperIsEquivalentToRef) -{ - test_filter(mtx_u_complex, dmtx_u_complex, 0.5, false); + test_filter(mtx_l_complex, dmtx_l_complex, 0.5); } TEST_F(ParIlut, KernelComplexThresholdFilterNoneLowerIsEquivalentToRef) { - test_filter(mtx_l_complex, dmtx_l_complex, 0, true); -} - - -TEST_F(ParIlut, KernelComplexThresholdFilterNoneUpperIsEquivalentToRef) -{ - test_filter(mtx_u_complex, dmtx_u_complex, 0, false); + test_filter(mtx_l_complex, dmtx_l_complex, 0); } TEST_F(ParIlut, KernelComplexThresholdFilterAllLowerIsEquivalentToRef) { - test_filter(mtx_l_complex, dmtx_l_complex, 1e6, true); -} - - -TEST_F(ParIlut, KernelComplexThresholdFilterAllUpperIsEquivalentToRef) -{ - test_filter(mtx_u_complex, dmtx_u_complex, 1e6, false); + test_filter(mtx_l_complex, dmtx_l_complex, 1e6); } diff --git a/reference/factorization/par_ilut_kernels.cpp b/reference/factorization/par_ilut_kernels.cpp index f5ef413314d..2a23c48fb90 100644 --- a/reference/factorization/par_ilut_kernels.cpp +++ b/reference/factorization/par_ilut_kernels.cpp @@ -44,6 +44,11 @@ OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. #include +#include "core/components/prefix_sum.hpp" +#include "core/matrix/coo_builder.hpp" +#include "core/matrix/csr_builder.hpp" + + namespace gko { namespace kernels { namespace reference { @@ -56,19 +61,23 @@ namespace par_ilut_factorization { template -remove_complex threshold_select( - std::shared_ptr exec, const ValueType *values, - IndexType size, IndexType rank) +void threshold_select(std::shared_ptr exec, + const matrix::Csr *m, + IndexType rank, Array &tmp, + Array> &, + remove_complex &threshold) { - Array data(exec, size); - std::copy_n(values, size, data.get_data()); + auto values = m->get_const_values(); + IndexType size = m->get_num_stored_elements(); + tmp.resize_and_reset(size); + std::copy_n(values, size, tmp.get_data()); - auto begin = data.get_data(); + auto begin = tmp.get_data(); auto target = begin + rank; auto end = begin + size; std::nth_element(begin, target, end, [](ValueType a, ValueType b) { return abs(a) < abs(b); }); - return abs(*target); + threshold = abs(*target); } @@ -80,9 +89,8 @@ template void threshold_filter(std::shared_ptr exec, const matrix::Csr *a, remove_complex threshold, - Array &new_row_ptrs_array, - Array &new_col_idxs_array, - Array &new_vals_array, bool /* is_lower */) + matrix::Csr *m_out, + matrix::Coo *m_out_coo) { auto num_rows = a->get_size()[0]; auto row_ptrs = a->get_const_row_ptrs(); @@ -90,34 +98,41 @@ void threshold_filter(std::shared_ptr exec, auto vals = a->get_const_values(); // first sweep: count nnz for each row - new_row_ptrs_array.resize_and_reset(num_rows + 1); - auto new_row_ptrs = new_row_ptrs_array.get_data(); + auto new_row_ptrs = m_out->get_row_ptrs(); for (size_type row = 0; row < num_rows; ++row) { IndexType count{}; for (size_type nz = row_ptrs[row]; nz < size_type(row_ptrs[row + 1]); ++nz) { count += abs(vals[nz]) >= threshold || col_idxs[nz] == row; } - new_row_ptrs[row + 1] = count; + new_row_ptrs[row] = count; } - // build row pointers: exclusive scan (thus the + 1) - new_row_ptrs[0] = 0; - std::partial_sum(new_row_ptrs + 1, new_row_ptrs + num_rows + 1, - new_row_ptrs + 1); + // build row pointers + prefix_sum(exec, new_row_ptrs, num_rows + 1); // second sweep: accumulate non-zeros auto new_nnz = new_row_ptrs[num_rows]; - new_col_idxs_array.resize_and_reset(new_nnz); - new_vals_array.resize_and_reset(new_nnz); - auto new_col_idxs = new_col_idxs_array.get_data(); - auto new_vals = new_vals_array.get_data(); + // resize arrays and update aliases + matrix::CsrBuilder builder{m_out}; + builder.get_col_idx_array().resize_and_reset(new_nnz); + builder.get_value_array().resize_and_reset(new_nnz); + auto new_col_idxs = m_out->get_col_idxs(); + auto new_vals = m_out->get_values(); + matrix::CooBuilder coo_builder{m_out_coo}; + coo_builder.get_row_idx_array().resize_and_reset(new_nnz); + coo_builder.get_col_idx_array() = + Array::view(exec, new_nnz, new_col_idxs); + coo_builder.get_value_array() = + Array::view(exec, new_nnz, new_vals); + auto new_row_idxs = m_out_coo->get_row_idxs(); for (size_type row = 0; row < num_rows; ++row) { auto new_nz = new_row_ptrs[row]; for (size_type nz = row_ptrs[row]; nz < size_type(row_ptrs[row + 1]); ++nz) { if (abs(vals[nz]) >= threshold || col_idxs[nz] == row) { + new_row_idxs[new_nz] = row; new_col_idxs[new_nz] = col_idxs[nz]; new_vals[new_nz] = vals[nz]; ++new_nz; diff --git a/reference/test/factorization/par_ilut_kernels.cpp b/reference/test/factorization/par_ilut_kernels.cpp index 55be3c35636..494fd38ed7f 100644 --- a/reference/test/factorization/par_ilut_kernels.cpp +++ b/reference/test/factorization/par_ilut_kernels.cpp @@ -68,32 +68,32 @@ class ParIlut : public ::testing::Test { ref(gko::ReferenceExecutor::create()), exec(std::static_pointer_cast(ref)), - mtx1(gko::initialize({{.1, .1, -1., -2.}, - {0., .1, -2., -3.}, - {0., 0., -1., -1.}, - {0., 0., 0., 1.}}, + mtx1(gko::initialize({{.1, 0., 0., 0.}, + {.1, .1, 0., 0.}, + {-1., -2., -1., 0.}, + {-2., -3., -1., 1.}}, ref)), - mtx1_expect_thrm2(gko::initialize({{.1, 0., 0., -2.}, - {0., .1, -2., -3.}, - {0., 0., -1., 0.}, - {0., 0., 0., 1.}}, + mtx1_expect_thrm2(gko::initialize({{.1, 0., 0., 0.}, + {0., .1, 0., 0.}, + {0., -2., -1., 0.}, + {-2., -3., 0., 1.}}, ref)), mtx1_expect_thrm3(gko::initialize({{.1, 0., 0., 0.}, - {0., .1, 0., -3.}, + {0., .1, 0., 0.}, {0., 0., -1., 0.}, - {0., 0., 0., 1.}}, + {0., -3., 0., 1.}}, ref)), mtx1_complex(gko::initialize( - {{.1 + 0. * i, -1. + .1 * i, -1. + i, 1. - 2. * i}, - {0. * i, .1 - i, -2. + .2 * i, -3. - .1 * i}, - {0. * i, 0. * i, -1. - .3 * i, -1. + .1 * i}, - {0. * i, 0. * i, 0. * i, .1 + 2. * i}}, + {{.1 + 0. * i, 0. * i, 0. * i, 0. * i}, + {-1. + .1 * i, .1 - i, 0. * i, 0. * i}, + {-1. + i, -2. + .2 * i, -1. - .3 * i, 0. * i}, + {1. - 2. * i, -3. - .1 * i, -1. + .1 * i, .1 + 2. * i}}, ref)), mtx1_expect_complex_thrm(gko::initialize( - {{.1 + 0. * i, 0. * i, -1. + i, 1. - 2. * i}, - {0. * i, .1 - i, -2. + .2 * i, -3. - .1 * i}, - {0. * i, 0. * i, -1. - .3 * i, 0. * i}, - {0. * i, 0. * i, 0. * i, .1 + 2. * i}}, + {{.1 + 0. * i, 0. * i, 0. * i, 0. * i}, + {0. * i, .1 - i, 0. * i, 0. * i}, + {-1. + i, -2. + .2 * i, -1. - .3 * i, 0. * i}, + {1. - 2. * i, -3. - .1 * i, 0. * i, .1 + 2. * i}}, ref)) {} @@ -101,32 +101,35 @@ class ParIlut : public ::testing::Test { void test_select(const std::unique_ptr &mtx, index_type rank, value_type expected, value_type tolerance = 0.0) { - auto vals = mtx->get_const_values(); - auto size = index_type(mtx->get_num_stored_elements()); + using ValueType = typename Mtx::value_type; + gko::remove_complex result{}; - auto result = - gko::kernels::reference::par_ilut_factorization::threshold_select( - ref, vals, size, rank); + gko::remove_complex res{}; + gko::remove_complex dres{}; + gko::Array tmp(ref); + gko::Array> tmp2(ref); + gko::kernels::reference::par_ilut_factorization::threshold_select( + ref, mtx.get(), rank, tmp, tmp2, result); ASSERT_NEAR(result, expected, tolerance); } - template + template > void test_filter(const std::unique_ptr &mtx, value_type threshold, const std::unique_ptr &expected) { - gko::Array new_row_ptrs(exec); - gko::Array new_col_idxs(exec); - gko::Array new_vals(exec); + auto res_mtx = Mtx::create(exec, mtx->get_size()); + auto res_mtx_coo = Coo::create(exec, mtx->get_size()); gko::kernels::reference::par_ilut_factorization::threshold_filter( - ref, mtx.get(), threshold, new_row_ptrs, new_col_idxs, new_vals, - false); - auto res_mtx = Mtx::create(exec, mtx->get_size(), new_vals, - new_col_idxs, new_row_ptrs); + ref, mtx.get(), threshold, res_mtx.get(), res_mtx_coo.get()); GKO_ASSERT_MTX_NEAR(expected, res_mtx, 0); GKO_ASSERT_MTX_EQ_SPARSITY(expected, res_mtx); + GKO_ASSERT_MTX_NEAR(res_mtx, res_mtx_coo, 0); + GKO_ASSERT_MTX_EQ_SPARSITY(res_mtx, res_mtx_coo); } std::complex i;