diff --git a/core/config/config_helper.hpp b/core/config/config_helper.hpp index f84e6799bf7..c99997f3188 100644 --- a/core/config/config_helper.hpp +++ b/core/config/config_helper.hpp @@ -199,7 +199,9 @@ get_value(const pnode& config) * This is specialization for floating point type */ template -inline std::enable_if_t::value, ValueType> +inline std::enable_if_t::value || + std::is_same::value, + ValueType> get_value(const pnode& config) { auto val = config.get_real(); diff --git a/core/test/utils.hpp b/core/test/utils.hpp index 6c0ce1f7e0a..f6a0bc49d31 100644 --- a/core/test/utils.hpp +++ b/core/test/utils.hpp @@ -161,7 +161,9 @@ using ComplexValueIndexTypes = ::testing::Types) gko::int32>, std::tuple, gko::int32>, OPTIONAL(std::tuple, gko::int64>) - std::tuple, gko::int64>> ; + std::tuple < std::complex, + gko::int64 >> + ; #else ::testing::Types, gko::int32>) std::tuple, gko::int32>, @@ -317,7 +319,7 @@ struct TupleTypenameNameGenerator { }; -namespace detail { +namespace temporary_test { // singly linked list of all our supported precisions @@ -346,10 +348,10 @@ struct next_precision_impl> { }; -} // namespace detail +} // namespace temporary_test template -using next_precision = typename detail::next_precision_impl::type; +using next_precision = typename temporary_test::next_precision_impl::type; #define SKIP_IF_HALF(type) \ diff --git a/cuda/solver/batch_cg_kernels.cu b/cuda/solver/batch_cg_kernels.cu index cff72652629..c7ffa0394ac 100644 --- a/cuda/solver/batch_cg_kernels.cu +++ b/cuda/solver/batch_cg_kernels.cu @@ -100,13 +100,14 @@ template using settings = gko::kernels::batch_cg::settings; -template +template class kernel_caller { public: - using value_type = CuValueType; + using cu_value_type = cuda_type; + ; kernel_caller(std::shared_ptr exec, - const settings> settings) + const settings> settings) : exec_{std::move(exec)}, settings_{settings} {} @@ -116,36 +117,36 @@ public: void launch_apply_kernel( const gko::kernels::batch_cg::storage_config& sconf, LogType& logger, PrecType& prec, const BatchMatrixType& mat, - const value_type* const __restrict__ b_values, - value_type* const __restrict__ x_values, - value_type* const __restrict__ workspace_data, const int& block_size, + const cu_value_type* const __restrict__ b_values, + cu_value_type* const __restrict__ x_values, + cu_value_type* const __restrict__ workspace_data, const int& block_size, const size_t& shared_size) const { apply_kernel <<get_stream()>>>(sconf, settings_.max_iterations, - settings_.residual_tol, logger, prec, mat, - b_values, x_values, workspace_data); + as_cuda_type(settings_.residual_tol), + logger, prec, mat, b_values, x_values, + workspace_data); } template void call_kernel( LogType logger, const BatchMatrixType& mat, PrecType prec, - const gko::batch::multi_vector::uniform_batch& b, - const gko::batch::multi_vector::uniform_batch& x) const + const gko::batch::multi_vector::uniform_batch& b, + const gko::batch::multi_vector::uniform_batch& x) const { - using real_type = gko::remove_complex; + using real_type = gko::remove_complex; const size_type num_batch_items = mat.num_batch_items; constexpr int align_multiple = 8; const int padded_num_rows = ceildiv(mat.num_rows, align_multiple) * align_multiple; - const int shmem_per_blk = - get_max_dynamic_shared_memory(exec_); + const int shmem_per_blk = get_max_dynamic_shared_memory< + StopType, PrecType, LogType, BatchMatrixType, cu_value_type>(exec_); const int block_size = get_num_threads_per_block( + BatchMatrixType, cu_value_type>( exec_, mat.num_rows); GKO_ASSERT(block_size >= 2 * config::warp_size); @@ -153,18 +154,18 @@ public: padded_num_rows, mat.get_single_item_num_nnz()); const auto sconf = gko::kernels::batch_cg::compute_shared_storage( + cu_value_type>( shmem_per_blk, padded_num_rows, mat.get_single_item_num_nnz(), b.num_rhs); const size_t shared_size = - sconf.n_shared * padded_num_rows * sizeof(value_type) + + sconf.n_shared * padded_num_rows * sizeof(cu_value_type) + (sconf.prec_shared ? prec_size : 0); - auto workspace = gko::array( + auto workspace = gko::array( exec_, - sconf.gmem_stride_bytes * num_batch_items / sizeof(value_type)); - GKO_ASSERT(sconf.gmem_stride_bytes % sizeof(value_type) == 0); + sconf.gmem_stride_bytes * num_batch_items / sizeof(cu_value_type)); + GKO_ASSERT(sconf.gmem_stride_bytes % sizeof(cu_value_type) == 0); - value_type* const workspace_data = workspace.get_data(); + cu_value_type* const workspace_data = workspace.get_data(); // Template parameters launch_apply_kernel @@ -212,7 +213,7 @@ public: private: std::shared_ptr exec_; - const settings> settings_; + const settings> settings_; }; @@ -225,9 +226,8 @@ void apply(std::shared_ptr exec, batch::MultiVector* const x, batch::log::detail::log_data>& logdata) { - using cu_value_type = cuda_type; auto dispatcher = batch::solver::create_dispatcher( - kernel_caller(exec, settings), settings, mat, precon); + kernel_caller(exec, settings), settings, mat, precon); dispatcher.apply(b, x, logdata); } diff --git a/hip/solver/batch_cg_kernels.hip.cpp b/hip/solver/batch_cg_kernels.hip.cpp index 450d02a302c..9d9d0a52df6 100644 --- a/hip/solver/batch_cg_kernels.hip.cpp +++ b/hip/solver/batch_cg_kernels.hip.cpp @@ -79,13 +79,13 @@ template using settings = gko::kernels::batch_cg::settings; -template +template class kernel_caller { public: - using value_type = HipValueType; + using hip_value_type = hip_type; kernel_caller(std::shared_ptr exec, - const settings> settings) + const settings> settings) : exec_{exec}, settings_{settings} {} @@ -96,15 +96,16 @@ class kernel_caller { const gko::kernels::batch_cg::storage_config& sconf, LogType& logger, PrecType& prec, const BatchMatrixType& mat, const value_type* const __restrict__ b_values, - value_type* const __restrict__ x_values, - value_type* const __restrict__ workspace_data, const int& block_size, - const size_t& shared_size) const + hip_value_type* const __restrict__ x_values, + hip_value_type* const __restrict__ workspace_data, + const int& block_size, const size_t& shared_size) const { apply_kernel <<get_stream()>>>(sconf, settings_.max_iterations, - settings_.residual_tol, logger, prec, mat, - b_values, x_values, workspace_data); + as_hip_type(settings_.residual_tol), + logger, prec, mat, b_values, x_values, + workspace_data); } @@ -112,10 +113,10 @@ class kernel_caller { typename LogType> void call_kernel( LogType logger, const BatchMatrixType& mat, PrecType prec, - const gko::batch::multi_vector::uniform_batch& b, - const gko::batch::multi_vector::uniform_batch& x) const + const gko::batch::multi_vector::uniform_batch& b, + const gko::batch::multi_vector::uniform_batch& x) const { - using real_type = gko::remove_complex; + using real_type = gko::remove_complex; const size_type num_batch_items = mat.num_batch_items; constexpr int align_multiple = 8; const int padded_num_rows = @@ -134,18 +135,18 @@ class kernel_caller { padded_num_rows, mat.get_single_item_num_nnz()); const auto sconf = gko::kernels::batch_cg::compute_shared_storage( + hip_value_type>( shmem_per_blk, padded_num_rows, mat.get_single_item_num_nnz(), b.num_rhs); const size_t shared_size = - sconf.n_shared * padded_num_rows * sizeof(value_type) + + sconf.n_shared * padded_num_rows * sizeof(hip_value_type) + (sconf.prec_shared ? prec_size : 0); - auto workspace = gko::array( + auto workspace = gko::array( exec_, - sconf.gmem_stride_bytes * num_batch_items / sizeof(value_type)); - GKO_ASSERT(sconf.gmem_stride_bytes % sizeof(value_type) == 0); + sconf.gmem_stride_bytes * num_batch_items / sizeof(hip_value_type)); + GKO_ASSERT(sconf.gmem_stride_bytes % sizeof(hip_value_type) == 0); - value_type* const workspace_data = workspace.get_data(); + hip_value_type* const workspace_data = workspace.get_data(); // Template parameters launch_apply_kernel exec_; - const settings> settings_; + const settings> settings_; }; @@ -206,9 +207,8 @@ void apply(std::shared_ptr exec, batch::MultiVector* const x, batch::log::detail::log_data>& logdata) { - using hip_value_type = hip_type; auto dispatcher = batch::solver::create_dispatcher( - kernel_caller(exec, settings), settings, mat, precon); + kernel_caller(exec, settings), settings, mat, precon); dispatcher.apply(b, x, logdata); }