From 7539bb332aa3bb365d76fcffe452f48b2821b3fe Mon Sep 17 00:00:00 2001 From: "Yu-Hsiang M. Tsai" Date: Mon, 23 Oct 2023 21:17:00 +0200 Subject: [PATCH] refine the kernel MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Co-authored-by: Pratik Nayak Co-authored-by: Thomas Grützmacher --- common/cuda_hip/matrix/csr_common.hpp.inc | 1 - common/cuda_hip/matrix/csr_kernels.hpp.inc | 12 ++++++++---- dpcpp/matrix/csr_kernels.dp.cpp | 14 ++++++++------ omp/matrix/csr_kernels.cpp | 11 ++++++++--- 4 files changed, 24 insertions(+), 14 deletions(-) diff --git a/common/cuda_hip/matrix/csr_common.hpp.inc b/common/cuda_hip/matrix/csr_common.hpp.inc index 0fce02aecfa..35718464c42 100644 --- a/common/cuda_hip/matrix/csr_common.hpp.inc +++ b/common/cuda_hip/matrix/csr_common.hpp.inc @@ -102,7 +102,6 @@ __global__ __launch_bounds__(default_block_size) void check_diagonal_entries( if (tile_grp.thread_rank() == 0) { *has_all_diags = false; } - return; } } } diff --git a/common/cuda_hip/matrix/csr_kernels.hpp.inc b/common/cuda_hip/matrix/csr_kernels.hpp.inc index 3f02337747e..6a6c590f540 100644 --- a/common/cuda_hip/matrix/csr_kernels.hpp.inc +++ b/common/cuda_hip/matrix/csr_kernels.hpp.inc @@ -826,15 +826,19 @@ __global__ __launch_bounds__(default_block_size) void add_scaled_identity( auto tile_grp = group::tiled_partition(group::this_thread_block()); const auto warpid = thread::get_subwarp_id_flat(); - const auto num_warps = thread::get_subwarp_num_flat(); if (warpid < num_rows) { const auto tid_in_warp = tile_grp.thread_rank(); const IndexType row_start = row_ptrs[warpid]; const IndexType num_nz = row_ptrs[warpid + 1] - row_start; + const auto beta_val = beta[0]; + const auto alpha_val = alpha[0]; for (IndexType iz = tid_in_warp; iz < num_nz; iz += warp_size) { - values[iz + row_start] *= beta[0]; - if (col_idxs[iz + row_start] == warpid) { - values[iz + row_start] += alpha[0]; + if (beta_val != one()) { + values[iz + row_start] *= beta_val; + 
} + if (col_idxs[iz + row_start] == warpid && + alpha_val != zero()) { + values[iz + row_start] += alpha_val; } } } diff --git a/dpcpp/matrix/csr_kernels.dp.cpp b/dpcpp/matrix/csr_kernels.dp.cpp index c5a8e3ef4d4..915e2027a26 100644 --- a/dpcpp/matrix/csr_kernels.dp.cpp +++ b/dpcpp/matrix/csr_kernels.dp.cpp @@ -899,7 +899,6 @@ void check_diagonal_entries(const IndexType num_min_rows_cols, if (tile_grp.thread_rank() == 0) { *has_all_diags = false; } - return; } } } @@ -921,16 +920,19 @@ void add_scaled_identity(const ValueType* const __restrict__ alpha, group::this_thread_block(item_ct1)); const auto row = thread::get_subwarp_id_flat(item_ct1); - const auto num_warps = - thread::get_subwarp_num_flat(item_ct1); if (row < num_rows) { const auto tid_in_warp = tile_grp.thread_rank(); const auto row_start = row_ptrs[row]; const auto num_nz = row_ptrs[row + 1] - row_start; + const auto beta_val = beta[0]; + const auto alpha_val = alpha[0]; for (IndexType iz = tid_in_warp; iz < num_nz; iz += subgroup_size) { - values[iz + row_start] *= beta[0]; - if (col_idxs[iz + row_start] == row) { - values[iz + row_start] += alpha[0]; + if (beta_val != one()) { + values[iz + row_start] *= beta_val; + } + if (col_idxs[iz + row_start] == row && + alpha_val != zero()) { + values[iz + row_start] += alpha_val; } } } diff --git a/omp/matrix/csr_kernels.cpp b/omp/matrix/csr_kernels.cpp index 7d4a5a7ebd1..1757b4b8a25 100644 --- a/omp/matrix/csr_kernels.cpp +++ b/omp/matrix/csr_kernels.cpp @@ -1134,12 +1134,17 @@ void add_scaled_identity(std::shared_ptr exec, const auto nrows = static_cast(mtx->get_size()[0]); const auto row_ptrs = mtx->get_const_row_ptrs(); const auto vals = mtx->get_values(); + const auto beta_val = beta->get_const_values()[0]; + const auto alpha_val = alpha->get_const_values()[0]; #pragma omp parallel for for (IndexType row = 0; row < nrows; row++) { for (IndexType iz = row_ptrs[row]; iz < row_ptrs[row + 1]; iz++) { - vals[iz] *= beta->get_const_values()[0]; - if (row == 
mtx->get_const_col_idxs()[iz]) { - vals[iz] += alpha->get_const_values()[0]; + if (beta_val != one()) { + vals[iz] *= beta_val; + } + if (row == mtx->get_const_col_idxs()[iz] && + alpha_val != zero()) { + vals[iz] += alpha_val; } } }