Skip to content

Commit

Permalink
review update
Browse files Browse the repository at this point in the history
  • Loading branch information
yhmtsai committed Nov 29, 2019
1 parent 2e4b094 commit 33fb780
Show file tree
Hide file tree
Showing 4 changed files with 205 additions and 173 deletions.
7 changes: 4 additions & 3 deletions cuda/base/cusparse_bindings.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -844,7 +844,8 @@ GKO_BIND_CUSPARSE64_CSRSM_ANALYSIS(ValueType, detail::not_implemented);
size_type n, const ValueType *one, const cusparseMatDescr_t descr, \
const ValueType *csrVal, const int32 *csrRowPtr, \
const int32 *csrColInd, cusparseSolveAnalysisInfo_t factor_info, \
ValueType *rhs, int32 rhs_stride, ValueType *sol, int32 sol_stride) \
const ValueType *rhs, int32 rhs_stride, ValueType *sol, \
int32 sol_stride) \
{ \
GKO_ASSERT_NO_CUSPARSE_ERRORS( \
CusparseName(handle, trans, m, n, as_culibs_type(one), descr, \
Expand All @@ -862,8 +863,8 @@ GKO_BIND_CUSPARSE64_CSRSM_ANALYSIS(ValueType, detail::not_implemented);
size_type n, const ValueType *one, const cusparseMatDescr_t descr, \
const ValueType *csrVal, const int64 *csrRowPtr, \
const int64 *csrColInd, cusparseSolveAnalysisInfo_t factor_info, \
ValueType *rhs, int64 rhs_stride, ValueType *sol, int64 sol_stride) \
GKO_NOT_IMPLEMENTED; \
const ValueType *rhs, int64 rhs_stride, ValueType *sol, \
int64 sol_stride) GKO_NOT_IMPLEMENTED; \
static_assert(true, \
"This assert is used to counter the false positive extra " \
"semi-colon warnings")
Expand Down
221 changes: 122 additions & 99 deletions cuda/solver/common_trs_kernels.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -58,7 +58,9 @@ namespace gko {
namespace solver {


struct SolveStruct {};
struct SolveStruct {
virtual void dummy() {}
};


namespace cuda {
Expand Down Expand Up @@ -88,10 +90,15 @@ struct SolveStruct : gko::solver::SolveStruct {
algorithm = 0;
policy = CUSPARSE_SOLVE_POLICY_USE_LEVEL;
}
SolveStruct(const SolveStruct &) : SolveStruct() {}
SolveStruct(SolveStruct &&) : SolveStruct() {}
SolveStruct &operator=(const SolveStruct &) { return *this; }
SolveStruct &operator=(SolveStruct &&) { return *this; }

SolveStruct(const SolveStruct &) = delete;

SolveStruct(SolveStruct &&) = delete;

SolveStruct &operator=(const SolveStruct &) = delete;

SolveStruct &operator=(SolveStruct &&) = delete;

~SolveStruct()
{
cusparseDestroyMatDescr(factor_descr);
Expand Down Expand Up @@ -124,10 +131,15 @@ struct SolveStruct : gko::solver::SolveStruct {
GKO_ASSERT_NO_CUSPARSE_ERRORS(
cusparseSetMatDiagType(factor_descr, CUSPARSE_DIAG_TYPE_NON_UNIT));
}
SolveStruct(const SolveStruct &) : SolveStruct() {}
SolveStruct(SolveStruct &&) : SolveStruct() {}
SolveStruct &operator=(const SolveStruct &) { return *this; }
SolveStruct &operator=(SolveStruct &&) { return *this; }

SolveStruct(const SolveStruct &) = delete;

SolveStruct(SolveStruct &&) = delete;

SolveStruct &operator=(const SolveStruct &) = delete;

SolveStruct &operator=(SolveStruct &&) = delete;

~SolveStruct()
{
cusparseDestroyMatDescr(factor_descr);
Expand Down Expand Up @@ -170,8 +182,7 @@ void should_perform_transpose_kernel(std::shared_ptr<const CudaExecutor> exec,
void init_struct_kernel(std::shared_ptr<const CudaExecutor> exec,
std::shared_ptr<solver::SolveStruct> &solve_struct)
{
solve_struct = std::dynamic_pointer_cast<solver::SolveStruct>(
std::make_shared<solver::cuda::SolveStruct>());
solve_struct = std::make_shared<solver::cuda::SolveStruct>();
}


Expand All @@ -181,70 +192,74 @@ void generate_kernel(std::shared_ptr<const CudaExecutor> exec,
solver::SolveStruct *solve_struct,
const gko::size_type num_rhs, bool is_upper)
{
auto cuda_solve_struct =
reinterpret_cast<solver::cuda::SolveStruct *>(solve_struct);
if (cusparse::is_supported<ValueType, IndexType>::value) {
auto handle = exec->get_cusparse_handle();
if (is_upper) {
GKO_ASSERT_NO_CUSPARSE_ERRORS(cusparseSetMatFillMode(
cuda_solve_struct->factor_descr, CUSPARSE_FILL_MODE_UPPER));
}
if (auto cuda_solve_struct =
dynamic_cast<solver::cuda::SolveStruct *>(solve_struct)) {
auto handle = exec->get_cusparse_handle();
if (is_upper) {
GKO_ASSERT_NO_CUSPARSE_ERRORS(cusparseSetMatFillMode(
cuda_solve_struct->factor_descr, CUSPARSE_FILL_MODE_UPPER));
}


#if (defined(CUDA_VERSION) && (CUDA_VERSION >= 9020))


ValueType one = 1.0;

{
cusparse::pointer_mode_guard pm_guard(handle);
cusparse::buffer_size_ext(
handle, cuda_solve_struct->algorithm,
CUSPARSE_OPERATION_NON_TRANSPOSE, CUSPARSE_OPERATION_TRANSPOSE,
matrix->get_size()[0], num_rhs,
matrix->get_num_stored_elements(), &one,
cuda_solve_struct->factor_descr, matrix->get_const_values(),
matrix->get_const_row_ptrs(), matrix->get_const_col_idxs(),
nullptr, num_rhs, cuda_solve_struct->solve_info,
cuda_solve_struct->policy,
&cuda_solve_struct->factor_work_size);

// allocate workspace
if (cuda_solve_struct->factor_work_vec != nullptr) {
exec->free(cuda_solve_struct->factor_work_vec);
ValueType one = 1.0;

{
cusparse::pointer_mode_guard pm_guard(handle);
cusparse::buffer_size_ext(
handle, cuda_solve_struct->algorithm,
CUSPARSE_OPERATION_NON_TRANSPOSE,
CUSPARSE_OPERATION_TRANSPOSE, matrix->get_size()[0],
num_rhs, matrix->get_num_stored_elements(), &one,
cuda_solve_struct->factor_descr, matrix->get_const_values(),
matrix->get_const_row_ptrs(), matrix->get_const_col_idxs(),
nullptr, num_rhs, cuda_solve_struct->solve_info,
cuda_solve_struct->policy,
&cuda_solve_struct->factor_work_size);

// allocate workspace
if (cuda_solve_struct->factor_work_vec != nullptr) {
exec->free(cuda_solve_struct->factor_work_vec);
}
cuda_solve_struct->factor_work_vec =
exec->alloc<void *>(cuda_solve_struct->factor_work_size);

cusparse::csrsm2_analysis(
handle, cuda_solve_struct->algorithm,
CUSPARSE_OPERATION_NON_TRANSPOSE,
CUSPARSE_OPERATION_TRANSPOSE, matrix->get_size()[0],
num_rhs, matrix->get_num_stored_elements(), &one,
cuda_solve_struct->factor_descr, matrix->get_const_values(),
matrix->get_const_row_ptrs(), matrix->get_const_col_idxs(),
nullptr, num_rhs, cuda_solve_struct->solve_info,
cuda_solve_struct->policy,
cuda_solve_struct->factor_work_vec);
}
cuda_solve_struct->factor_work_vec =
exec->alloc<void *>(cuda_solve_struct->factor_work_size);

cusparse::csrsm2_analysis(
handle, cuda_solve_struct->algorithm,
CUSPARSE_OPERATION_NON_TRANSPOSE, CUSPARSE_OPERATION_TRANSPOSE,
matrix->get_size()[0], num_rhs,
matrix->get_num_stored_elements(), &one,
cuda_solve_struct->factor_descr, matrix->get_const_values(),
matrix->get_const_row_ptrs(), matrix->get_const_col_idxs(),
nullptr, num_rhs, cuda_solve_struct->solve_info,
cuda_solve_struct->policy, cuda_solve_struct->factor_work_vec);
}


#elif (defined(CUDA_VERSION) && (CUDA_VERSION < 9020))


{
cusparse::pointer_mode_guard pm_guard(handle);
cusparse::csrsm_analysis(
handle, CUSPARSE_OPERATION_NON_TRANSPOSE, matrix->get_size()[0],
matrix->get_num_stored_elements(),
cuda_solve_struct->factor_descr, matrix->get_const_values(),
matrix->get_const_row_ptrs(), matrix->get_const_col_idxs(),
cuda_solve_struct->solve_info);
}
{
cusparse::pointer_mode_guard pm_guard(handle);
cusparse::csrsm_analysis(
handle, CUSPARSE_OPERATION_NON_TRANSPOSE,
matrix->get_size()[0], matrix->get_num_stored_elements(),
cuda_solve_struct->factor_descr, matrix->get_const_values(),
matrix->get_const_row_ptrs(), matrix->get_const_col_idxs(),
cuda_solve_struct->solve_info);
}


#endif


} else {
GKO_NOT_SUPPORTED(solve_struct);
}
} else {
GKO_NOT_IMPLEMENTED;
}
Expand All @@ -261,64 +276,72 @@ void solve_kernel(std::shared_ptr<const CudaExecutor> exec,
matrix::Dense<ValueType> *x)
{
using vec = matrix::Dense<ValueType>;
auto cuda_solve_struct =
reinterpret_cast<const solver::cuda::SolveStruct *>(solve_struct);

if (cusparse::is_supported<ValueType, IndexType>::value) {
ValueType one = 1.0;
auto handle = exec->get_cusparse_handle();
if (auto cuda_solve_struct =
dynamic_cast<const solver::cuda::SolveStruct *>(solve_struct)) {
ValueType one = 1.0;
auto handle = exec->get_cusparse_handle();


#if (defined(CUDA_VERSION) && (CUDA_VERSION >= 9020))


x->copy_from(gko::lend(b));
{
cusparse::pointer_mode_guard pm_guard(handle);
cusparse::csrsm2_solve(
handle, cuda_solve_struct->algorithm,
CUSPARSE_OPERATION_NON_TRANSPOSE, CUSPARSE_OPERATION_TRANSPOSE,
matrix->get_size()[0], b->get_stride(),
matrix->get_num_stored_elements(), &one,
cuda_solve_struct->factor_descr, matrix->get_const_values(),
matrix->get_const_row_ptrs(), matrix->get_const_col_idxs(),
x->get_values(), b->get_stride(), cuda_solve_struct->solve_info,
cuda_solve_struct->policy, cuda_solve_struct->factor_work_vec);
}
x->copy_from(gko::lend(b));
{
cusparse::pointer_mode_guard pm_guard(handle);
cusparse::csrsm2_solve(
handle, cuda_solve_struct->algorithm,
CUSPARSE_OPERATION_NON_TRANSPOSE,
CUSPARSE_OPERATION_TRANSPOSE, matrix->get_size()[0],
b->get_stride(), matrix->get_num_stored_elements(), &one,
cuda_solve_struct->factor_descr, matrix->get_const_values(),
matrix->get_const_row_ptrs(), matrix->get_const_col_idxs(),
x->get_values(), b->get_stride(),
cuda_solve_struct->solve_info, cuda_solve_struct->policy,
cuda_solve_struct->factor_work_vec);
}


#elif (defined(CUDA_VERSION) && (CUDA_VERSION < 9020))


{
cusparse::pointer_mode_guard pm_guard(handle);
if (b->get_stride() == 1) {
auto temp_b = const_cast<ValueType *>(b->get_const_values());
cusparse::csrsm_solve(
handle, CUSPARSE_OPERATION_NON_TRANSPOSE,
matrix->get_size()[0], b->get_stride(), &one,
cuda_solve_struct->factor_descr, matrix->get_const_values(),
matrix->get_const_row_ptrs(), matrix->get_const_col_idxs(),
cuda_solve_struct->solve_info, temp_b, b->get_size()[0],
x->get_values(), x->get_size()[0]);
} else {
dense::transpose(exec, trans_b, b);
dense::transpose(exec, trans_x, x);
cusparse::csrsm_solve(
handle, CUSPARSE_OPERATION_NON_TRANSPOSE,
matrix->get_size()[0], trans_b->get_size()[0], &one,
cuda_solve_struct->factor_descr, matrix->get_const_values(),
matrix->get_const_row_ptrs(), matrix->get_const_col_idxs(),
cuda_solve_struct->solve_info, trans_b->get_values(),
trans_b->get_size()[1], trans_x->get_values(),
trans_x->get_size()[1]);
dense::transpose(exec, x, trans_x);
{
cusparse::pointer_mode_guard pm_guard(handle);
if (b->get_stride() == 1) {
cusparse::csrsm_solve(
handle, CUSPARSE_OPERATION_NON_TRANSPOSE,
matrix->get_size()[0], b->get_stride(), &one,
cuda_solve_struct->factor_descr,
matrix->get_const_values(),
matrix->get_const_row_ptrs(),
matrix->get_const_col_idxs(),
cuda_solve_struct->solve_info, b->get_const_values(),
b->get_size()[0], x->get_values(), x->get_size()[0]);
} else {
dense::transpose(exec, trans_b, b);
dense::transpose(exec, trans_x, x);
cusparse::csrsm_solve(
handle, CUSPARSE_OPERATION_NON_TRANSPOSE,
matrix->get_size()[0], trans_b->get_size()[0], &one,
cuda_solve_struct->factor_descr,
matrix->get_const_values(),
matrix->get_const_row_ptrs(),
matrix->get_const_col_idxs(),
cuda_solve_struct->solve_info, trans_b->get_values(),
trans_b->get_size()[1], trans_x->get_values(),
trans_x->get_size()[1]);
dense::transpose(exec, x, trans_x);
}
}
}


#endif


} else {
GKO_NOT_SUPPORTED(solve_struct);
}
} else {
GKO_NOT_IMPLEMENTED;
}
Expand Down
11 changes: 6 additions & 5 deletions hip/base/hipsparse_bindings.hip.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -388,8 +388,9 @@ GKO_BIND_HIPSPARSE64_CSRSV2_ANALYSIS(ValueType, detail::not_implemented);
hipsparseHandle_t handle, hipsparseOperation_t trans, size_type m, \
size_type nnz, const ValueType *one, const hipsparseMatDescr_t descr, \
const ValueType *csrVal, const int32 *csrRowPtr, \
const int32 *csrColInd, csrsv2Info_t factor_info, ValueType *rhs, \
ValueType *sol, hipsparseSolvePolicy_t policy, void *factor_work_vec) \
const int32 *csrColInd, csrsv2Info_t factor_info, \
const ValueType *rhs, ValueType *sol, hipsparseSolvePolicy_t policy, \
void *factor_work_vec) \
{ \
GKO_ASSERT_NO_HIPSPARSE_ERRORS( \
HipsparseName(handle, trans, m, nnz, as_hiplibs_type(one), descr, \
Expand All @@ -406,9 +407,9 @@ GKO_BIND_HIPSPARSE64_CSRSV2_ANALYSIS(ValueType, detail::not_implemented);
hipsparseHandle_t handle, hipsparseOperation_t trans, size_type m, \
size_type nnz, const ValueType *one, const hipsparseMatDescr_t descr, \
const ValueType *csrVal, const int64 *csrRowPtr, \
const int64 *csrColInd, csrsv2Info_t factor_info, ValueType *rhs, \
ValueType *sol, hipsparseSolvePolicy_t policy, void *factor_work_vec) \
GKO_NOT_IMPLEMENTED; \
const int64 *csrColInd, csrsv2Info_t factor_info, \
const ValueType *rhs, ValueType *sol, hipsparseSolvePolicy_t policy, \
void *factor_work_vec) GKO_NOT_IMPLEMENTED; \
static_assert(true, \
"This assert is used to counter the false positive extra " \
"semi-colon warnings")
Expand Down
Loading

0 comments on commit 33fb780

Please sign in to comment.