Skip to content

Commit

Permalink
hip does not support atomic on 16 bits
Browse files Browse the repository at this point in the history
  • Loading branch information
yhmtsai committed Sep 18, 2024
1 parent a11fb2e commit 4da6a9a
Show file tree
Hide file tree
Showing 4 changed files with 72 additions and 37 deletions.
25 changes: 17 additions & 8 deletions common/cuda_hip/factorization/par_ic_kernels.hpp.inc
Original file line number Diff line number Diff line change
Expand Up @@ -97,14 +97,23 @@ void compute_factor(std::shared_ptr<const DefaultExecutor> exec,
auto nnz = l->get_num_stored_elements();
auto num_blocks = ceildiv(nnz, default_block_size);
if (num_blocks > 0) {
for (size_type i = 0; i < iterations; ++i) {
kernel::ic_sweep<<<num_blocks, default_block_size, 0,
exec->get_stream()>>>(
a_lower->get_const_row_idxs(), a_lower->get_const_col_idxs(),
as_device_type(a_lower->get_const_values()),
l->get_const_row_ptrs(), l->get_const_col_idxs(),
as_device_type(l->get_values()),
static_cast<IndexType>(l->get_num_stored_elements()));
#ifdef GKO_COMPILING_HIP
if constexpr (std::is_same<remove_complex<ValueType>, half>::value) {
// HIP does not support 16bit atomic operation
GKO_NOT_SUPPORTED(a_lower);
} else
#endif
{
for (size_type i = 0; i < iterations; ++i) {
kernel::ic_sweep<<<num_blocks, default_block_size, 0,
exec->get_stream()>>>(
a_lower->get_const_row_idxs(),
a_lower->get_const_col_idxs(),
as_device_type(a_lower->get_const_values()),
l->get_const_row_ptrs(), l->get_const_col_idxs(),
as_device_type(l->get_values()),
static_cast<IndexType>(l->get_num_stored_elements()));
}
}
}
}
Expand Down
30 changes: 20 additions & 10 deletions common/cuda_hip/factorization/par_ilu_kernels.hpp.inc
Original file line number Diff line number Diff line change
Expand Up @@ -69,16 +69,26 @@ void compute_l_u_factors(std::shared_ptr<const DefaultExecutor> exec,
const auto grid_dim = static_cast<uint32>(
ceildiv(num_elements, static_cast<size_type>(block_size)));
if (grid_dim > 0) {
for (size_type i = 0; i < iterations; ++i) {
kernel::compute_l_u_factors<<<grid_dim, block_size, 0,
exec->get_stream()>>>(
num_elements, system_matrix->get_const_row_idxs(),
system_matrix->get_const_col_idxs(),
as_device_type(system_matrix->get_const_values()),
l_factor->get_const_row_ptrs(), l_factor->get_const_col_idxs(),
as_device_type(l_factor->get_values()),
u_factor->get_const_row_ptrs(), u_factor->get_const_col_idxs(),
as_device_type(u_factor->get_values()));
#ifdef GKO_COMPILING_HIP
if constexpr (std::is_same<remove_complex<ValueType>, half>::value) {
// HIP does not support 16bit atomic operation
GKO_NOT_SUPPORTED(system_matrix);
} else
#endif
{
for (size_type i = 0; i < iterations; ++i) {
kernel::compute_l_u_factors<<<grid_dim, block_size, 0,
exec->get_stream()>>>(
num_elements, system_matrix->get_const_row_idxs(),
system_matrix->get_const_col_idxs(),
as_device_type(system_matrix->get_const_values()),
l_factor->get_const_row_ptrs(),
l_factor->get_const_col_idxs(),
as_device_type(l_factor->get_values()),
u_factor->get_const_row_ptrs(),
u_factor->get_const_col_idxs(),
as_device_type(u_factor->get_values()));
}
}
}
}
Expand Down
22 changes: 15 additions & 7 deletions hip/factorization/par_ict_kernels.hip.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -118,13 +118,21 @@ void compute_factor(syn::value_list<int, subwarp_size>,
auto block_size = default_block_size / subwarp_size;
auto num_blocks = ceildiv(total_nnz, block_size);
if (num_blocks > 0) {
kernel::ict_sweep<subwarp_size>
<<<num_blocks, default_block_size, 0, exec->get_stream()>>>(
a->get_const_row_ptrs(), a->get_const_col_idxs(),
as_device_type(a->get_const_values()), l->get_const_row_ptrs(),
l_coo->get_const_row_idxs(), l->get_const_col_idxs(),
as_device_type(l->get_values()),
static_cast<IndexType>(l->get_num_stored_elements()));
#ifdef GKO_COMPILING_HIP
if constexpr (std::is_same<remove_complex<ValueType>, half>::value) {
// HIP does not support 16bit atomic operation
GKO_NOT_SUPPORTED(l);
} else
#endif
{
kernel::ict_sweep<subwarp_size>
<<<num_blocks, default_block_size, 0, exec->get_stream()>>>(
a->get_const_row_ptrs(), a->get_const_col_idxs(),
as_device_type(a->get_const_values()),
l->get_const_row_ptrs(), l_coo->get_const_row_idxs(),
l->get_const_col_idxs(), as_device_type(l->get_values()),
static_cast<IndexType>(l->get_num_stored_elements()));
}
}
}

Expand Down
32 changes: 20 additions & 12 deletions hip/factorization/par_ilut_sweep_kernels.hip.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -65,18 +65,26 @@ void compute_l_u_factors(syn::value_list<int, subwarp_size>,
auto block_size = default_block_size / subwarp_size;
auto num_blocks = ceildiv(total_nnz, block_size);
if (num_blocks > 0) {
kernel::sweep<subwarp_size>
<<<num_blocks, default_block_size, 0, exec->get_stream()>>>(
a->get_const_row_ptrs(), a->get_const_col_idxs(),
as_device_type(a->get_const_values()), l->get_const_row_ptrs(),
l_coo->get_const_row_idxs(), l->get_const_col_idxs(),
as_device_type(l->get_values()),
static_cast<IndexType>(l->get_num_stored_elements()),
u_coo->get_const_row_idxs(), u_coo->get_const_col_idxs(),
as_device_type(u->get_values()), u_csc->get_const_row_ptrs(),
u_csc->get_const_col_idxs(),
as_device_type(u_csc->get_values()),
static_cast<IndexType>(u->get_num_stored_elements()));
#ifdef GKO_COMPILING_HIP
if constexpr (std::is_same<remove_complex<ValueType>, half>::value) {
// HIP does not support 16bit atomic operation
GKO_NOT_SUPPORTED(a);
} else
#endif
{
kernel::sweep<subwarp_size>
<<<num_blocks, default_block_size, 0, exec->get_stream()>>>(
a->get_const_row_ptrs(), a->get_const_col_idxs(),
as_device_type(a->get_const_values()),
l->get_const_row_ptrs(), l_coo->get_const_row_idxs(),
l->get_const_col_idxs(), as_device_type(l->get_values()),
static_cast<IndexType>(l->get_num_stored_elements()),
u_coo->get_const_row_idxs(), u_coo->get_const_col_idxs(),
as_device_type(u->get_values()),
u_csc->get_const_row_ptrs(), u_csc->get_const_col_idxs(),
as_device_type(u_csc->get_values()),
static_cast<IndexType>(u->get_num_stored_elements()));
}
}
}

Expand Down

0 comments on commit 4da6a9a

Please sign in to comment.