Skip to content

Commit

Permalink
refine the kernel
Browse files (browse the repository at this point in the history)
Co-authored-by: Pratik Nayak <pratikvn@protonmail.com>
Co-authored-by: Thomas Grützmacher <thomas.gruetzmacher@kit.edu>
  • Loading branch information
3 people committed Oct 24, 2023
1 parent 3c8191c commit a1225b5
Show file tree
Hide file tree
Showing 4 changed files with 24 additions and 14 deletions.
1 change: 0 additions & 1 deletion common/cuda_hip/matrix/csr_common.hpp.inc
Original file line number Diff line number Diff line change
Expand Up @@ -102,7 +102,6 @@ __global__ __launch_bounds__(default_block_size) void check_diagonal_entries(
if (tile_grp.thread_rank() == 0) {
*has_all_diags = false;
}
return;
}
}
}
Expand Down
12 changes: 8 additions & 4 deletions common/cuda_hip/matrix/csr_kernels.hpp.inc
Original file line number Diff line number Diff line change
Expand Up @@ -826,15 +826,19 @@ __global__ __launch_bounds__(default_block_size) void add_scaled_identity(
auto tile_grp =
group::tiled_partition<warp_size>(group::this_thread_block());
const auto warpid = thread::get_subwarp_id_flat<warp_size, IndexType>();
const auto num_warps = thread::get_subwarp_num_flat<warp_size, IndexType>();
if (warpid < num_rows) {
const auto tid_in_warp = tile_grp.thread_rank();
const IndexType row_start = row_ptrs[warpid];
const IndexType num_nz = row_ptrs[warpid + 1] - row_start;
const auto beta_val = beta[0];
const auto alpha_val = alpha[0];
for (IndexType iz = tid_in_warp; iz < num_nz; iz += warp_size) {
values[iz + row_start] *= beta[0];
if (col_idxs[iz + row_start] == warpid) {
values[iz + row_start] += alpha[0];
if (beta_val != one<ValueType>()) {
values[iz + row_start] *= beta_val;
}
if (col_idxs[iz + row_start] == warpid &&
alpha_val != zero<ValueType>()) {
values[iz + row_start] += alpha_val;
}
}
}
Expand Down
14 changes: 8 additions & 6 deletions dpcpp/matrix/csr_kernels.dp.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -899,7 +899,6 @@ void check_diagonal_entries(const IndexType num_min_rows_cols,
if (tile_grp.thread_rank() == 0) {
*has_all_diags = false;
}
return;
}
}
}
Expand All @@ -921,16 +920,19 @@ void add_scaled_identity(const ValueType* const __restrict__ alpha,
group::this_thread_block(item_ct1));
const auto row =
thread::get_subwarp_id_flat<subgroup_size, IndexType>(item_ct1);
const auto num_warps =
thread::get_subwarp_num_flat<subgroup_size, IndexType>(item_ct1);
if (row < num_rows) {
const auto tid_in_warp = tile_grp.thread_rank();
const auto row_start = row_ptrs[row];
const auto num_nz = row_ptrs[row + 1] - row_start;
const auto beta_val = beta[0];
const auto alpha_val = alpha[0];
for (IndexType iz = tid_in_warp; iz < num_nz; iz += subgroup_size) {
values[iz + row_start] *= beta[0];
if (col_idxs[iz + row_start] == row) {
values[iz + row_start] += alpha[0];
if (beta_val != one<ValueType>()) {
values[iz + row_start] *= beta_val;
}
if (col_idxs[iz + row_start] == row &&
alpha_val != zero<ValueType>()) {
values[iz + row_start] += alpha_val;
}
}
}
Expand Down
11 changes: 8 additions & 3 deletions omp/matrix/csr_kernels.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1134,12 +1134,17 @@ void add_scaled_identity(std::shared_ptr<const OmpExecutor> exec,
const auto nrows = static_cast<IndexType>(mtx->get_size()[0]);
const auto row_ptrs = mtx->get_const_row_ptrs();
const auto vals = mtx->get_values();
const auto beta_val = beta->get_const_values()[0];
const auto alpha_val = alpha->get_const_values()[0];
#pragma omp parallel for
for (IndexType row = 0; row < nrows; row++) {
for (IndexType iz = row_ptrs[row]; iz < row_ptrs[row + 1]; iz++) {
vals[iz] *= beta->get_const_values()[0];
if (row == mtx->get_const_col_idxs()[iz]) {
vals[iz] += alpha->get_const_values()[0];
if (beta_val != one<ValueType>()) {
vals[iz] *= beta_val;
}
if (row == mtx->get_const_col_idxs()[iz] &&
alpha_val != zero<ValueType>()) {
vals[iz] += alpha_val;
}
}
}
Expand Down

0 comments on commit a1225b5

Please sign in to comment.