Skip to content

Commit

Permalink
add cusparse csrsort bindings
Browse files Browse the repository at this point in the history
  • Loading branch information
upsj committed Nov 27, 2019
1 parent bb8b830 commit bbf7883
Show file tree
Hide file tree
Showing 3 changed files with 178 additions and 7 deletions.
74 changes: 74 additions & 0 deletions cuda/base/cusparse_bindings.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -968,6 +968,80 @@ GKO_BIND_CUSPARSE64_CSRSM_SOLVE(ValueType, detail::not_implemented);
#endif


template <typename IndexType>
void create_identity_permutation(cusparseHandle_t handle, IndexType size,
IndexType *permutation) GKO_NOT_IMPLEMENTED;

template <>
inline void create_identity_permutation<int32>(cusparseHandle_t handle,
int32 size, int32 *permutation)
{
GKO_ASSERT_NO_CUSPARSE_ERRORS(
cusparseCreateIdentityPermutation(handle, size, permutation));
}


template <typename IndexType>
void csrsort_buffer_size(cusparseHandle_t handle, IndexType m, IndexType n,
IndexType nnz, const IndexType *row_ptrs,
const IndexType *col_idxs,
size_type &buffer_size) GKO_NOT_IMPLEMENTED;

template <>
inline void csrsort_buffer_size<int32>(cusparseHandle_t handle, int32 m,
int32 n, int32 nnz,
const int32 *row_ptrs,
const int32 *col_idxs,
size_type &buffer_size)
{
GKO_ASSERT_NO_CUSPARSE_ERRORS(cusparseXcsrsort_bufferSizeExt(
handle, m, n, nnz, row_ptrs, col_idxs, &buffer_size));
}


template <typename IndexType>
void csrsort(cusparseHandle_t handle, IndexType m, IndexType n, IndexType nnz,
const cusparseMatDescr_t descr, const IndexType *row_ptrs,
IndexType *col_idxs, IndexType *permutation,
void *buffer) GKO_NOT_IMPLEMENTED;

template <>
inline void csrsort<int32>(cusparseHandle_t handle, int32 m, int32 n, int32 nnz,
const cusparseMatDescr_t descr,
const int32 *row_ptrs, int32 *col_idxs,
int32 *permutation, void *buffer)
{
GKO_ASSERT_NO_CUSPARSE_ERRORS(cusparseXcsrsort(
handle, m, n, nnz, descr, row_ptrs, col_idxs, permutation, buffer));
}


template <typename IndexType, typename ValueType>
void gather(cusparseHandle_t handle, IndexType nnz, const ValueType *in,
ValueType *out, const IndexType *permutation) GKO_NOT_IMPLEMENTED;

#define GKO_BIND_CUSPARSE_GATHER(ValueType, CusparseName) \
template <> \
inline void gather<int32, ValueType>(cusparseHandle_t handle, int32 nnz, \
const ValueType *in, ValueType *out, \
const int32 *permutation) \
{ \
GKO_ASSERT_NO_CUSPARSE_ERRORS( \
CusparseName(handle, nnz, as_culibs_type(in), as_culibs_type(out), \
permutation, CUSPARSE_INDEX_BASE_ZERO)); \
} \
static_assert(true, \
"This assert is used to counter the false positive extra " \
"semi-colon warnings")

GKO_BIND_CUSPARSE_GATHER(float, cusparseSgthr);
GKO_BIND_CUSPARSE_GATHER(double, cusparseDgthr);
GKO_BIND_CUSPARSE_GATHER(std::complex<float>, cusparseCgthr);
GKO_BIND_CUSPARSE_GATHER(std::complex<double>, cusparseZgthr);

#undef GKO_BIND_CUSPARSE_GATHER


} // namespace cusparse
} // namespace cuda
} // namespace kernels
Expand Down
41 changes: 40 additions & 1 deletion cuda/matrix/csr_kernels.cu
Original file line number Diff line number Diff line change
Expand Up @@ -931,7 +931,46 @@ GKO_INSTANTIATE_FOR_EACH_VALUE_AND_INDEX_TYPE(
template <typename ValueType, typename IndexType>
void sort_by_column_index(std::shared_ptr<const CudaExecutor> exec,
matrix::Csr<ValueType, IndexType> *to_sort)
GKO_NOT_IMPLEMENTED;
{
if (cusparse::is_supported<ValueType, IndexType>::value) {
auto handle = exec->get_cusparse_handle();
auto descr = cusparse::create_mat_descr();
auto m = IndexType(to_sort->get_size()[0]);
auto n = IndexType(to_sort->get_size()[1]);
auto nnz = IndexType(to_sort->get_num_stored_elements());
auto row_ptrs = to_sort->get_const_row_ptrs();
auto col_idxs = to_sort->get_col_idxs();
auto vals = to_sort->get_values();

// copy values
Array<ValueType> tmp_vals_array(exec, nnz);
exec->copy_from(exec.get(), nnz, vals, tmp_vals_array.get_data());
auto tmp_vals = tmp_vals_array.get_const_data();

// init identity permutation
Array<IndexType> permutation_array(exec, nnz);
auto permutation = permutation_array.get_data();
cusparse::create_identity_permutation(handle, nnz, permutation);

// allocate buffer
size_type buffer_size{};
cusparse::csrsort_buffer_size(handle, m, n, nnz, row_ptrs, col_idxs,
buffer_size);
Array<char> buffer_array{exec, buffer_size};
auto buffer = buffer_array.get_data();

// sort column indices
cusparse::csrsort(handle, m, n, nnz, descr, row_ptrs, col_idxs,
permutation, buffer);

// sort values
cusparse::gather(handle, nnz, tmp_vals, vals, permutation);

cusparse::destroy(descr);
} else {
GKO_NOT_IMPLEMENTED;
}
}

GKO_INSTANTIATE_FOR_EACH_VALUE_AND_INDEX_TYPE(
GKO_DECLARE_CSR_SORT_BY_COLUMN_INDEX);
Expand Down
70 changes: 64 additions & 6 deletions cuda/test/matrix/csr_kernels.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -63,7 +63,7 @@ class Csr : public ::testing::Test {
using ComplexVec = gko::matrix::Dense<std::complex<double>>;
using ComplexMtx = gko::matrix::Csr<std::complex<double>>;

Csr() : rand_engine(42) {}
Csr() : mtx_size(532, 231), rand_engine(42) {}

void SetUp()
{
Expand Down Expand Up @@ -93,11 +93,11 @@ class Csr : public ::testing::Test {
int num_vectors = 1)
{
mtx = Mtx::create(ref, strategy);
mtx->copy_from(gen_mtx<Vec>(532, 231, 1));
mtx->copy_from(gen_mtx<Vec>(mtx_size[0], mtx_size[1], 1));
square_mtx = Mtx::create(ref, strategy);
square_mtx->copy_from(gen_mtx<Vec>(532, 532, 1));
expected = gen_mtx<Vec>(532, num_vectors, 1);
y = gen_mtx<Vec>(231, num_vectors, 1);
square_mtx->copy_from(gen_mtx<Vec>(mtx_size[0], mtx_size[0], 1));
expected = gen_mtx<Vec>(mtx_size[0], num_vectors, 1);
y = gen_mtx<Vec>(mtx_size[1], num_vectors, 1);
alpha = gko::initialize<Vec>({2.0}, ref);
beta = gko::initialize<Vec>({-1.0}, ref);
dmtx = Mtx::create(cuda, strategy);
Expand All @@ -118,14 +118,48 @@ class Csr : public ::testing::Test {
std::shared_ptr<ComplexMtx::strategy_type> strategy)
{
complex_mtx = ComplexMtx::create(ref, strategy);
complex_mtx->copy_from(gen_mtx<ComplexVec>(532, 231, 1));
complex_mtx->copy_from(
gen_mtx<ComplexVec>(mtx_size[0], mtx_size[1], 1));
complex_dmtx = ComplexMtx::create(cuda, strategy);
complex_dmtx->copy_from(complex_mtx.get());
}

struct matrix_pair {
std::unique_ptr<Mtx> ref;
std::unique_ptr<Mtx> cuda;
};

matrix_pair gen_unsorted_mtx()
{
constexpr int min_nnz_per_row = 2; // Must be larger/equal than 2
auto local_mtx_ref =
gen_mtx<Mtx>(mtx_size[0], mtx_size[1], min_nnz_per_row);
for (size_t row = 0; row < mtx_size[0]; ++row) {
const auto row_ptrs = local_mtx_ref->get_const_row_ptrs();
const auto start_row = row_ptrs[row];
auto col_idx = local_mtx_ref->get_col_idxs() + start_row;
auto vals = local_mtx_ref->get_values() + start_row;
const auto nnz_in_this_row = row_ptrs[row + 1] - row_ptrs[row];
auto swap_idx_dist =
std::uniform_int_distribution<>(0, nnz_in_this_row - 1);
// shuffle `nnz_in_this_row / 2` times
for (size_t perm = 0; perm < nnz_in_this_row; perm += 2) {
const auto idx1 = swap_idx_dist(rand_engine);
const auto idx2 = swap_idx_dist(rand_engine);
std::swap(col_idx[idx1], col_idx[idx2]);
std::swap(vals[idx1], vals[idx2]);
}
}
auto local_mtx_cuda = Mtx::create(cuda);
local_mtx_cuda->copy_from(local_mtx_ref.get());

return {std::move(local_mtx_ref), std::move(local_mtx_cuda)};
}

std::shared_ptr<gko::ReferenceExecutor> ref;
std::shared_ptr<const gko::CudaExecutor> cuda;

const gko::dim<2> mtx_size;
std::ranlux48 rand_engine;

std::unique_ptr<Mtx> mtx;
Expand Down Expand Up @@ -576,4 +610,28 @@ TEST_F(Csr, MoveToHybridIsEquivalentToRef)
}


TEST_F(Csr, SortSortedMatrixIsEquivalentToRef)
{
set_up_apply_data(std::make_shared<Mtx::automatical>());

mtx->sort_by_column_index();
dmtx->sort_by_column_index();

// Values must be unchanged, therefore, tolerance is `0`
GKO_ASSERT_MTX_NEAR(mtx, dmtx, 0);
}


TEST_F(Csr, SortUnsortedMatrixIsEquivalentToRef)
{
auto uns_mtx = gen_unsorted_mtx();

uns_mtx.ref->sort_by_column_index();
uns_mtx.cuda->sort_by_column_index();

// Values must be unchanged, therefore, tolerance is `0`
GKO_ASSERT_MTX_NEAR(uns_mtx.ref, uns_mtx.cuda, 0);
}


} // namespace

0 comments on commit bbf7883

Please sign in to comment.