Skip to content

Commit

Permalink
add dpcpp csr diagonal missing components
Browse files Browse the repository at this point in the history
- check_diagonal_entries
- add_scaled_identity
  • Loading branch information
yhmtsai committed Oct 18, 2023
1 parent 9f71bcd commit b88f097
Show file tree
Hide file tree
Showing 2 changed files with 99 additions and 9 deletions.
102 changes: 99 additions & 3 deletions dpcpp/matrix/csr_kernels.dp.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -871,6 +871,74 @@ void extract_diagonal(size_type diag_size, size_type nnz,
GKO_ENABLE_DEFAULT_HOST(extract_diagonal, extract_diagonal);


template <typename IndexType>
void check_diagonal_entries(const IndexType num_min_rows_cols,
const IndexType* const __restrict__ row_ptrs,
const IndexType* const __restrict__ col_idxs,
bool* const __restrict__ has_all_diags,
sycl::nd_item<3> item_ct1)
{
constexpr int subgroup_size = config::warp_size;
auto tile_grp = group::tiled_partition<subgroup_size>(
group::this_thread_block(item_ct1));
const auto row =
thread::get_subwarp_id_flat<subgroup_size, IndexType>(item_ct1);
if (row < num_min_rows_cols) {
const auto tid_in_warp = tile_grp.thread_rank();
const auto row_start = row_ptrs[row];
const auto num_nz = row_ptrs[row + 1] - row_start;
bool row_has_diag_local{false};
for (IndexType iz = tid_in_warp; iz < num_nz; iz += subgroup_size) {
if (col_idxs[iz + row_start] == row) {
row_has_diag_local = true;
break;
}
}
auto row_has_diag = static_cast<bool>(tile_grp.any(row_has_diag_local));
if (!row_has_diag) {
if (tile_grp.thread_rank() == 0) {
*has_all_diags = false;
}
return;
}
}
}

GKO_ENABLE_DEFAULT_HOST(check_diagonal_entries, check_diagonal_entries);


template <typename ValueType, typename IndexType>
void add_scaled_identity(const ValueType* const __restrict__ alpha,
const ValueType* const __restrict__ beta,
const IndexType num_rows,
const IndexType* const __restrict__ row_ptrs,
const IndexType* const __restrict__ col_idxs,
ValueType* const __restrict__ values,
sycl::nd_item<3> item_ct1)
{
constexpr int subgroup_size = config::warp_size;
auto tile_grp = group::tiled_partition<subgroup_size>(
group::this_thread_block(item_ct1));
const auto row =
thread::get_subwarp_id_flat<subgroup_size, IndexType>(item_ct1);
const auto num_warps =
thread::get_subwarp_num_flat<subgroup_size, IndexType>(item_ct1);
if (row < num_rows) {
const auto tid_in_warp = tile_grp.thread_rank();
const auto row_start = row_ptrs[row];
const auto num_nz = row_ptrs[row + 1] - row_start;
for (IndexType iz = tid_in_warp; iz < num_nz; iz += subgroup_size) {
values[iz + row_start] *= beta[0];
if (col_idxs[iz + row_start] == row) {
values[iz + row_start] += alpha[0];
}
}
}
}

GKO_ENABLE_DEFAULT_HOST(add_scaled_identity, add_scaled_identity);


} // namespace kernel


Expand Down Expand Up @@ -2364,8 +2432,24 @@ GKO_INSTANTIATE_FOR_EACH_VALUE_AND_INDEX_TYPE(GKO_DECLARE_CSR_EXTRACT_DIAGONAL);
template <typename ValueType, typename IndexType>
void check_diagonal_entries_exist(
std::shared_ptr<const DpcppExecutor> exec,
const matrix::Csr<ValueType, IndexType>* const mtx,
bool& has_all_diags) GKO_NOT_IMPLEMENTED;
const matrix::Csr<ValueType, IndexType>* const mtx, bool& has_all_diags)
{
const size_type num_subgroup = mtx->get_size()[0];
if (num_subgroup > 0) {
const size_type num_blocks =
num_subgroup / (default_block_size / config::warp_size);
array<bool> has_diags(exec, {true});
kernel::check_diagonal_entries(
num_blocks, default_block_size, 0, exec->get_queue(),
static_cast<IndexType>(
std::min(mtx->get_size()[0], mtx->get_size()[1])),
mtx->get_const_row_ptrs(), mtx->get_const_col_idxs(),
has_diags.get_data());
has_all_diags = exec->copy_val_to_host(has_diags.get_const_data());
} else {
has_all_diags = true;
}
}

GKO_INSTANTIATE_FOR_EACH_VALUE_AND_INDEX_TYPE(
GKO_DECLARE_CSR_CHECK_DIAGONAL_ENTRIES_EXIST);
Expand All @@ -2376,7 +2460,19 @@ void add_scaled_identity(std::shared_ptr<const DpcppExecutor> exec,
const matrix::Dense<ValueType>* const alpha,
const matrix::Dense<ValueType>* const beta,
matrix::Csr<ValueType, IndexType>* const mtx)
GKO_NOT_IMPLEMENTED;
{
const auto nrows = mtx->get_size()[0];
if (nrows == 0) {
return;
}
const auto nthreads = nrows * config::warp_size;
const auto nblocks = ceildiv(nthreads, default_block_size);
kernel::add_scaled_identity(
nblocks, default_block_size, 0, exec->get_queue(),
alpha->get_const_values(), beta->get_const_values(),
static_cast<IndexType>(nrows), mtx->get_const_row_ptrs(),
mtx->get_const_col_idxs(), mtx->get_values());
}

GKO_INSTANTIATE_FOR_EACH_VALUE_AND_INDEX_TYPE(
GKO_DECLARE_CSR_ADD_SCALED_IDENTITY_KERNEL);
Expand Down
6 changes: 0 additions & 6 deletions test/matrix/csr_kernels2.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1311,9 +1311,6 @@ TEST_F(Csr, CreateSubMatrixIsEquivalentToRef)
}


#ifndef GKO_COMPILING_DPCPP


TEST_F(Csr, CanDetectMissingDiagonalEntry)
{
using T = double;
Expand Down Expand Up @@ -1359,6 +1356,3 @@ TEST_F(Csr, AddScaledIdentityToNonSquare)

GKO_ASSERT_MTX_NEAR(mtx, dmtx, r<value_type>::value);
}


#endif // GKO_COMPILING_DPCPP

0 comments on commit b88f097

Please sign in to comment.