Skip to content

Commit

Permalink
fix config, ambiguous namespace, and batch
Browse files Browse the repository at this point in the history
  • Loading branch information
yhmtsai committed Jul 3, 2024
1 parent 8def1b5 commit 390bd02
Show file tree
Hide file tree
Showing 4 changed files with 53 additions and 49 deletions.
4 changes: 3 additions & 1 deletion core/config/config_helper.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -199,7 +199,9 @@ get_value(const pnode& config)
* This is specialization for floating point type
*/
template <typename ValueType>
inline std::enable_if_t<std::is_floating_point<ValueType>::value, ValueType>
inline std::enable_if_t<std::is_floating_point<ValueType>::value ||
std::is_same<ValueType, half>::value,
ValueType>
get_value(const pnode& config)
{
auto val = config.get_real();
Expand Down
10 changes: 6 additions & 4 deletions core/test/utils.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -161,7 +161,9 @@ using ComplexValueIndexTypes =
::testing::Types<OPTIONAL(std::tuple < std::complex<gko::half>) gko::int32>,
std::tuple<std::complex<float>, gko::int32>,
OPTIONAL(std::tuple<std::complex<gko::half>, gko::int64>)
std::tuple<std::complex<float>, gko::int64>> ;
std::tuple < std::complex<float>,
gko::int64 >>
;
#else
::testing::Types<OPTIONAL(std::tuple<std::complex<gko::half>, gko::int32>)
std::tuple<std::complex<float>, gko::int32>,
Expand Down Expand Up @@ -317,7 +319,7 @@ struct TupleTypenameNameGenerator {
};


namespace detail {
namespace temporary_test {


// singly linked list of all our supported precisions
Expand Down Expand Up @@ -346,10 +348,10 @@ struct next_precision_impl<std::complex<T>> {
};


} // namespace detail
} // namespace temporary_test

template <typename T>
using next_precision = typename detail::next_precision_impl<T>::type;
using next_precision = typename temporary_test::next_precision_impl<T>::type;


#define SKIP_IF_HALF(type) \
Expand Down
48 changes: 24 additions & 24 deletions cuda/solver/batch_cg_kernels.cu
Original file line number Diff line number Diff line change
Expand Up @@ -100,13 +100,14 @@ template <typename T>
using settings = gko::kernels::batch_cg::settings<T>;


template <typename CuValueType>
template <typename ValueType>
class kernel_caller {
public:
using value_type = CuValueType;
using cu_value_type = cuda_type<ValueType>;
;

kernel_caller(std::shared_ptr<const DefaultExecutor> exec,
const settings<remove_complex<value_type>> settings)
const settings<remove_complex<ValueType>> settings)
: exec_{std::move(exec)}, settings_{settings}
{}

Expand All @@ -116,55 +117,55 @@ public:
void launch_apply_kernel(
const gko::kernels::batch_cg::storage_config& sconf, LogType& logger,
PrecType& prec, const BatchMatrixType& mat,
const value_type* const __restrict__ b_values,
value_type* const __restrict__ x_values,
value_type* const __restrict__ workspace_data, const int& block_size,
const cu_value_type* const __restrict__ b_values,
cu_value_type* const __restrict__ x_values,
cu_value_type* const __restrict__ workspace_data, const int& block_size,
const size_t& shared_size) const
{
apply_kernel<StopType, n_shared, prec_shared_bool>
<<<mat.num_batch_items, block_size, shared_size,
exec_->get_stream()>>>(sconf, settings_.max_iterations,
settings_.residual_tol, logger, prec, mat,
b_values, x_values, workspace_data);
as_cuda_type(settings_.residual_tol),
logger, prec, mat, b_values, x_values,
workspace_data);
}
template <typename BatchMatrixType, typename PrecType, typename StopType,
typename LogType>
void call_kernel(
LogType logger, const BatchMatrixType& mat, PrecType prec,
const gko::batch::multi_vector::uniform_batch<const value_type>& b,
const gko::batch::multi_vector::uniform_batch<value_type>& x) const
const gko::batch::multi_vector::uniform_batch<const cu_value_type>& b,
const gko::batch::multi_vector::uniform_batch<cu_value_type>& x) const
{
using real_type = gko::remove_complex<value_type>;
using real_type = gko::remove_complex<cu_value_type>;
const size_type num_batch_items = mat.num_batch_items;
constexpr int align_multiple = 8;
const int padded_num_rows =
ceildiv(mat.num_rows, align_multiple) * align_multiple;
const int shmem_per_blk =
get_max_dynamic_shared_memory<StopType, PrecType, LogType,
BatchMatrixType, value_type>(exec_);
const int shmem_per_blk = get_max_dynamic_shared_memory<
StopType, PrecType, LogType, BatchMatrixType, cu_value_type>(exec_);
const int block_size =
get_num_threads_per_block<StopType, PrecType, LogType,
BatchMatrixType, value_type>(
BatchMatrixType, cu_value_type>(
exec_, mat.num_rows);
GKO_ASSERT(block_size >= 2 * config::warp_size);
const size_t prec_size = PrecType::dynamic_work_size(
padded_num_rows, mat.get_single_item_num_nnz());
const auto sconf =
gko::kernels::batch_cg::compute_shared_storage<PrecType,
value_type>(
cu_value_type>(
shmem_per_blk, padded_num_rows, mat.get_single_item_num_nnz(),
b.num_rhs);
const size_t shared_size =
sconf.n_shared * padded_num_rows * sizeof(value_type) +
sconf.n_shared * padded_num_rows * sizeof(cu_value_type) +
(sconf.prec_shared ? prec_size : 0);
auto workspace = gko::array<value_type>(
auto workspace = gko::array<cu_value_type>(
exec_,
sconf.gmem_stride_bytes * num_batch_items / sizeof(value_type));
GKO_ASSERT(sconf.gmem_stride_bytes % sizeof(value_type) == 0);
sconf.gmem_stride_bytes * num_batch_items / sizeof(cu_value_type));
GKO_ASSERT(sconf.gmem_stride_bytes % sizeof(cu_value_type) == 0);
value_type* const workspace_data = workspace.get_data();
cu_value_type* const workspace_data = workspace.get_data();
// Template parameters launch_apply_kernel<StopType, n_shared,
// prec_shared>
Expand Down Expand Up @@ -212,7 +213,7 @@ public:
private:
std::shared_ptr<const DefaultExecutor> exec_;
const settings<remove_complex<value_type>> settings_;
const settings<remove_complex<ValueType>> settings_;
};
Expand All @@ -225,9 +226,8 @@ void apply(std::shared_ptr<const DefaultExecutor> exec,
batch::MultiVector<ValueType>* const x,
batch::log::detail::log_data<remove_complex<ValueType>>& logdata)
{
using cu_value_type = cuda_type<ValueType>;
auto dispatcher = batch::solver::create_dispatcher<ValueType>(
kernel_caller<cu_value_type>(exec, settings), settings, mat, precon);
kernel_caller<ValueType>(exec, settings), settings, mat, precon);
dispatcher.apply(b, x, logdata);
}
Expand Down
40 changes: 20 additions & 20 deletions hip/solver/batch_cg_kernels.hip.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -79,13 +79,13 @@ template <typename T>
using settings = gko::kernels::batch_cg::settings<T>;


template <typename HipValueType>
template <typename ValueType>
class kernel_caller {
public:
using value_type = HipValueType;
using hip_value_type = hip_type<ValueType>;

kernel_caller(std::shared_ptr<const DefaultExecutor> exec,
const settings<remove_complex<value_type>> settings)
const settings<remove_complex<ValueType>> settings)
: exec_{exec}, settings_{settings}
{}

Expand All @@ -96,26 +96,27 @@ class kernel_caller {
const gko::kernels::batch_cg::storage_config& sconf, LogType& logger,
PrecType& prec, const BatchMatrixType& mat,
const value_type* const __restrict__ b_values,
value_type* const __restrict__ x_values,
value_type* const __restrict__ workspace_data, const int& block_size,
const size_t& shared_size) const
hip_value_type* const __restrict__ x_values,
hip_value_type* const __restrict__ workspace_data,
const int& block_size, const size_t& shared_size) const
{
apply_kernel<StopType, n_shared, prec_shared_bool>
<<<mat.num_batch_items, block_size, shared_size,
exec_->get_stream()>>>(sconf, settings_.max_iterations,
settings_.residual_tol, logger, prec, mat,
b_values, x_values, workspace_data);
as_hip_type(settings_.residual_tol),
logger, prec, mat, b_values, x_values,
workspace_data);
}


template <typename BatchMatrixType, typename PrecType, typename StopType,
typename LogType>
void call_kernel(
LogType logger, const BatchMatrixType& mat, PrecType prec,
const gko::batch::multi_vector::uniform_batch<const value_type>& b,
const gko::batch::multi_vector::uniform_batch<value_type>& x) const
const gko::batch::multi_vector::uniform_batch<const hip_value_type>& b,
const gko::batch::multi_vector::uniform_batch<hip_value_type>& x) const
{
using real_type = gko::remove_complex<value_type>;
using real_type = gko::remove_complex<hip_value_type>;
const size_type num_batch_items = mat.num_batch_items;
constexpr int align_multiple = 8;
const int padded_num_rows =
Expand All @@ -134,18 +135,18 @@ class kernel_caller {
padded_num_rows, mat.get_single_item_num_nnz());
const auto sconf =
gko::kernels::batch_cg::compute_shared_storage<PrecType,
value_type>(
hip_value_type>(
shmem_per_blk, padded_num_rows, mat.get_single_item_num_nnz(),
b.num_rhs);
const size_t shared_size =
sconf.n_shared * padded_num_rows * sizeof(value_type) +
sconf.n_shared * padded_num_rows * sizeof(hip_value_type) +
(sconf.prec_shared ? prec_size : 0);
auto workspace = gko::array<value_type>(
auto workspace = gko::array<hip_value_type>(
exec_,
sconf.gmem_stride_bytes * num_batch_items / sizeof(value_type));
GKO_ASSERT(sconf.gmem_stride_bytes % sizeof(value_type) == 0);
sconf.gmem_stride_bytes * num_batch_items / sizeof(hip_value_type));
GKO_ASSERT(sconf.gmem_stride_bytes % sizeof(hip_value_type) == 0);

value_type* const workspace_data = workspace.get_data();
hip_value_type* const workspace_data = workspace.get_data();

// Template parameters launch_apply_kernel<StopType, n_shared,
// prec_shared)
Expand Down Expand Up @@ -193,7 +194,7 @@ class kernel_caller {

private:
std::shared_ptr<const DefaultExecutor> exec_;
const settings<remove_complex<value_type>> settings_;
const settings<remove_complex<ValueType>> settings_;
};


Expand All @@ -206,9 +207,8 @@ void apply(std::shared_ptr<const DefaultExecutor> exec,
batch::MultiVector<ValueType>* const x,
batch::log::detail::log_data<remove_complex<ValueType>>& logdata)
{
using hip_value_type = hip_type<ValueType>;
auto dispatcher = batch::solver::create_dispatcher<ValueType>(
kernel_caller<hip_value_type>(exec, settings), settings, mat, precon);
kernel_caller<ValueType>(exec, settings), settings, mat, precon);
dispatcher.apply(b, x, logdata);
}

Expand Down

0 comments on commit 390bd02

Please sign in to comment.