diff --git a/common/matrix/ell_kernels.hpp.inc b/common/matrix/ell_kernels.hpp.inc index 79ed4ed979b..2323d512258 100644 --- a/common/matrix/ell_kernels.hpp.inc +++ b/common/matrix/ell_kernels.hpp.inc @@ -41,8 +41,8 @@ __device__ void spmv_kernel( const size_type num_rows, const int num_worker_per_row, acc::range val, const IndexType *__restrict__ col, const size_type stride, const size_type num_stored_elements_per_row, - acc::range b, const size_type b_stride, - OutputValueType *__restrict__ c, const size_type c_stride, Closure op) + acc::range b, OutputValueType *__restrict__ c, + const size_type c_stride, Closure op) { const auto tidx = thread::get_thread_id_flat(); const decltype(tidx) column_id = blockIdx.y; @@ -109,12 +109,12 @@ __global__ __launch_bounds__(default_block_size) void spmv( const size_type num_rows, const int num_worker_per_row, acc::range val, const IndexType *__restrict__ col, const size_type stride, const size_type num_stored_elements_per_row, - acc::range b, const size_type b_stride, - OutputValueType *__restrict__ c, const size_type c_stride) + acc::range b, OutputValueType *__restrict__ c, + const size_type c_stride) { spmv_kernel( num_rows, num_worker_per_row, val, col, stride, - num_stored_elements_per_row, b, b_stride, c, c_stride, + num_stored_elements_per_row, b, c, c_stride, [](const OutputValueType &x, const OutputValueType &y) { return x; }); } @@ -126,8 +126,8 @@ __global__ __launch_bounds__(default_block_size) void spmv( acc::range alpha, acc::range val, const IndexType *__restrict__ col, const size_type stride, const size_type num_stored_elements_per_row, acc::range b, - const size_type b_stride, const OutputValueType *__restrict__ beta, - OutputValueType *__restrict__ c, const size_type c_stride) + const OutputValueType *__restrict__ beta, OutputValueType *__restrict__ c, + const size_type c_stride) { const OutputValueType alpha_val = alpha(0); const OutputValueType beta_val = beta[0]; @@ -138,14 +138,14 @@ __global__ __launch_bounds__(default_block_size) void spmv( if (atomic) { spmv_kernel( num_rows, num_worker_per_row, val, col, stride, - num_stored_elements_per_row, b, b_stride, c, c_stride, + num_stored_elements_per_row, b, c, c_stride, [&alpha_val](const OutputValueType &x, const OutputValueType &y) { return alpha_val * x; }); } else { spmv_kernel( num_rows, num_worker_per_row, val, col, stride, - num_stored_elements_per_row, b, b_stride, c, c_stride, + num_stored_elements_per_row, b, c, c_stride, [&alpha_val, &beta_val](const OutputValueType &x, const OutputValueType &y) { return alpha_val * x + beta_val * y; diff --git a/cuda/matrix/ell_kernels.cu b/cuda/matrix/ell_kernels.cu index 57617269ed0..8808a502ebb 100644 --- a/cuda/matrix/ell_kernels.cu +++ b/cuda/matrix/ell_kernels.cu @@ -159,8 +159,8 @@ void abstract_spmv(syn::value_list, int num_worker_per_row, <<>>( nrows, num_worker_per_row, as_cuda_accessor(a_vals), a->get_const_col_idxs(), stride, num_stored_elements_per_row, - as_cuda_accessor(b_vals), b->get_stride(), - as_cuda_type(c->get_values()), c->get_stride()); + as_cuda_accessor(b_vals), as_cuda_type(c->get_values()), + c->get_stride()); } else if (alpha != nullptr && beta != nullptr) { const auto alpha_val = gko::acc::range( std::array{1}, alpha->get_const_values()); @@ -169,7 +169,7 @@ void abstract_spmv(syn::value_list, int num_worker_per_row, nrows, num_worker_per_row, as_cuda_accessor(alpha_val), as_cuda_accessor(a_vals), a->get_const_col_idxs(), stride, num_stored_elements_per_row, as_cuda_accessor(b_vals), - b->get_stride(), as_cuda_type(beta->get_const_values()), + as_cuda_type(beta->get_const_values()), as_cuda_type(c->get_values()), c->get_stride()); } else { GKO_KERNEL_NOT_FOUND; diff --git a/dpcpp/matrix/ell_kernels.dp.cpp b/dpcpp/matrix/ell_kernels.dp.cpp index 01a40b5d9b5..b4dfeafeb54 100644 --- a/dpcpp/matrix/ell_kernels.dp.cpp +++ b/dpcpp/matrix/ell_kernels.dp.cpp @@ -56,24 +56,27 @@ namespace dpcpp { namespace ell { -template +template void spmv(std::shared_ptr exec, - const matrix::Ell *a, - const matrix::Dense *b, - matrix::Dense *c) GKO_NOT_IMPLEMENTED; + const matrix::Ell *a, + const matrix::Dense *b, + matrix::Dense *c) GKO_NOT_IMPLEMENTED; -GKO_INSTANTIATE_FOR_EACH_VALUE_AND_INDEX_TYPE(GKO_DECLARE_ELL_SPMV_KERNEL); +GKO_INSTANTIATE_FOR_EACH_MIXED_VALUE_AND_INDEX_TYPE( + GKO_DECLARE_ELL_SPMV_KERNEL); -template +template void advanced_spmv(std::shared_ptr exec, - const matrix::Dense *alpha, - const matrix::Ell *a, - const matrix::Dense *b, - const matrix::Dense *beta, - matrix::Dense *c) GKO_NOT_IMPLEMENTED; + const matrix::Dense *alpha, + const matrix::Ell *a, + const matrix::Dense *b, + const matrix::Dense *beta, + matrix::Dense *c) GKO_NOT_IMPLEMENTED; -GKO_INSTANTIATE_FOR_EACH_VALUE_AND_INDEX_TYPE( +GKO_INSTANTIATE_FOR_EACH_MIXED_VALUE_AND_INDEX_TYPE( GKO_DECLARE_ELL_ADVANCED_SPMV_KERNEL); diff --git a/hip/matrix/ell_kernels.hip.cpp b/hip/matrix/ell_kernels.hip.cpp index f1972635385..fc10fc7d0e1 100644 --- a/hip/matrix/ell_kernels.hip.cpp +++ b/hip/matrix/ell_kernels.hip.cpp @@ -164,7 +164,7 @@ void abstract_spmv(syn::value_list, int num_worker_per_row, dim3(grid_size), dim3(block_size), 0, 0, nrows, num_worker_per_row, as_hip_accessor(a_vals), a->get_const_col_idxs(), stride, num_stored_elements_per_row, as_hip_accessor(b_vals), - b->get_stride(), as_hip_type(c->get_values()), c->get_stride()); + as_hip_type(c->get_values()), c->get_stride()); } else if (alpha != nullptr && beta != nullptr) { const auto alpha_val = gko::acc::range( std::array{1}, alpha->get_const_values()); @@ -173,9 +173,8 @@ void abstract_spmv(syn::value_list, int num_worker_per_row, dim3(grid_size), dim3(block_size), 0, 0, nrows, num_worker_per_row, as_hip_accessor(alpha_val), as_hip_accessor(a_vals), a->get_const_col_idxs(), stride, num_stored_elements_per_row, - as_hip_accessor(b_vals), b->get_stride(), - as_hip_type(beta->get_const_values()), as_hip_type(c->get_values()), - c->get_stride()); + as_hip_accessor(b_vals), as_hip_type(beta->get_const_values()), + as_hip_type(c->get_values()), c->get_stride()); } else { GKO_KERNEL_NOT_FOUND; } diff --git a/omp/matrix/ell_kernels.cpp b/omp/matrix/ell_kernels.cpp index 040176053d6..db7fe3cbc0d 100644 --- a/omp/matrix/ell_kernels.cpp +++ b/omp/matrix/ell_kernels.cpp @@ -77,8 +77,8 @@ void spmv(std::shared_ptr exec, std::array{num_stored_elements_per_row * stride}, a->get_const_values()); const auto b_vals = gko::acc::range( - std::array{num_rows, b->get_stride()}, - b->get_const_values()); + std::array{{b->get_size()[0], b->get_size()[1]}}, + b->get_const_values(), std::array{{b->get_stride()}}); #pragma omp parallel for for (size_type row = 0; row < a->get_size()[0]; row++) { @@ -121,8 +121,8 @@ void advanced_spmv(std::shared_ptr exec, std::array{num_stored_elements_per_row * stride}, a->get_const_values()); const auto b_vals = gko::acc::range( - std::array{num_rows, b->get_stride()}, - b->get_const_values()); + std::array{{b->get_size()[0], b->get_size()[1]}}, + b->get_const_values(), std::array{{b->get_stride()}}); const auto alpha_val = OutputValueType(alpha->at(0, 0)); const auto beta_val = beta->at(0, 0); diff --git a/reference/matrix/ell_kernels.cpp b/reference/matrix/ell_kernels.cpp index 8b0f541799a..58df834bfba 100644 --- a/reference/matrix/ell_kernels.cpp +++ b/reference/matrix/ell_kernels.cpp @@ -73,8 +73,8 @@ void spmv(std::shared_ptr exec, std::array{num_stored_elements_per_row * stride}, a->get_const_values()); const auto b_vals = gko::acc::range( - std::array{num_rows, b->get_stride()}, - b->get_const_values()); + std::array{{b->get_size()[0], b->get_size()[1]}}, + b->get_const_values(), std::array{{b->get_stride()}}); for (size_type row = 0; row < a->get_size()[0]; row++) { for (size_type j = 0; j < c->get_size()[1]; j++) { @@ -116,8 +116,8 @@ void advanced_spmv(std::shared_ptr exec, std::array{num_stored_elements_per_row * stride}, a->get_const_values()); const auto b_vals = gko::acc::range( - std::array{num_rows, b->get_stride()}, - b->get_const_values()); + std::array{{b->get_size()[0], b->get_size()[1]}}, + b->get_const_values(), std::array{{b->get_stride()}}); const auto alpha_val = OutputValueType(alpha->at(0, 0)); const auto beta_val = beta->at(0, 0);