diff --git a/cuda/matrix/csr_kernels.cu b/cuda/matrix/csr_kernels.cu index d4107f980e9..ed721c979fe 100644 --- a/cuda/matrix/csr_kernels.cu +++ b/cuda/matrix/csr_kernels.cu @@ -40,8 +40,8 @@ OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. #include #include #include -#include #include +#include #include "core/matrix/dense_kernels.hpp" @@ -1138,7 +1138,7 @@ __global__ __launch_bounds__(default_block_size) void fill_in_ell( } -} // namespace kernel +} // namespace kernel template @@ -1187,10 +1187,11 @@ __global__ __launch_bounds__(default_block_size) void reduce_max_nnz_per_slice( constexpr auto warp_size = cuda_config::warp_size; const auto warpid = tidx / warp_size; const auto tid_in_warp = tidx % warp_size; + const auto slice_num = ceildiv(num_rows, slice_size); size_type thread_result = 0; for (auto i = tid_in_warp; i < slice_size; i += warp_size) { - if (warpid * warp_size + i < num_rows) { + if (warpid * slice_size + i < num_rows) { thread_result = max(thread_result, nnz_per_row[warpid * slice_size + i]); } @@ -1202,7 +1203,7 @@ __global__ __launch_bounds__(default_block_size) void reduce_max_nnz_per_slice( warp_tile, thread_result, [](const size_type &a, const size_type &b) { return max(a, b); }); - if (tid_in_warp == 0) { + if (tid_in_warp == 0 && warpid < slice_num) { result[warpid] = ceildiv(warp_result, stride_factor) * stride_factor; } } @@ -1250,6 +1251,7 @@ void calculate_total_cols(std::shared_ptr exec, as_cuda_type(nnz_per_row.get_const_data()), as_cuda_type(max_nnz_per_slice.get_data())); + grid_dim = ceildiv(slice_num, default_block_size); auto block_results = Array(exec, grid_dim); kernel::reduce_total_cols<< exec, auto max_nnz_per_slice = Array(exec, slice_num); - const auto grid_dim = ceildiv(slice_num, default_block_size); + auto grid_dim = + ceildiv(slice_num * cuda_config::warp_size, default_block_size); kernel::reduce_max_nnz_per_slice<<>>( num_rows, slice_size, stride_factor, as_cuda_type(nnz_per_row.get_const_data()), as_cuda_type(max_nnz_per_slice.get_data())); + grid_dim = ceildiv(slice_num, default_block_size); auto block_results = Array(exec, grid_dim); kernel::reduce_total_cols<<