Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Triangular solvers #1193

Open
wants to merge 2 commits into
base: develop
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
26 changes: 22 additions & 4 deletions common/cuda_hip/components/volatile.hpp.inc
Original file line number Diff line number Diff line change
Expand Up @@ -32,13 +32,22 @@ OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.

template <typename ValueType, typename IndexType>
__device__ __forceinline__
std::enable_if_t<!is_complex_s<ValueType>::value, ValueType>
std::enable_if_t<std::is_floating_point<ValueType>::value, ValueType>
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

does this have to do with int not having a definition of is_complex? We will probably run into issues with #1257

load(const ValueType* values, IndexType index)
{
const volatile ValueType* val = values + index;
return *val;
}

template <typename ValueType>
__device__ __forceinline__
std::enable_if_t<std::is_integral<ValueType>::value, ValueType>
load(const ValueType* values, int index)
{
const volatile ValueType* val = values + index;
return *val;
}

template <typename ValueType, typename IndexType>
__device__ __forceinline__ std::enable_if_t<
std::is_floating_point<ValueType>::value, thrust::complex<ValueType>>
Expand All @@ -50,9 +59,18 @@ load(const thrust::complex<ValueType>* values, IndexType index)
}

template <typename ValueType, typename IndexType>
__device__ __forceinline__
std::enable_if_t<!is_complex_s<ValueType>::value, void>
store(ValueType* values, IndexType index, ValueType value)
__device__ __forceinline__ void store(
ValueType* values, IndexType index,
std::enable_if_t<std::is_floating_point<ValueType>::value, ValueType> value)
{
volatile ValueType* val = values + index;
*val = value;
}

template <typename ValueType>
__device__ __forceinline__ void store(
ValueType* values, int index,
std::enable_if_t<std::is_integral<ValueType>::value, ValueType> value)
{
volatile ValueType* val = values + index;
*val = value;
Expand Down
8 changes: 4 additions & 4 deletions core/solver/lower_trs.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -135,8 +135,8 @@ void LowerTrs<ValueType, IndexType>::generate()
if (this->get_system_matrix()) {
this->get_executor()->run(lower_trs::make_generate(
this->get_system_matrix().get(), this->solve_struct_,
this->get_parameters().unit_diagonal, parameters_.algorithm,
parameters_.num_rhs));
this->get_parameters().unit_diagonal,
gko::lend(parameters_.strategy), parameters_.num_rhs));
}
}

Expand Down Expand Up @@ -178,8 +178,8 @@ void LowerTrs<ValueType, IndexType>::apply_impl(const LinOp* b, LinOp* x) const
}
exec->run(lower_trs::make_solve(
lend(this->get_system_matrix()), lend(this->solve_struct_),
this->get_parameters().unit_diagonal, parameters_.algorithm,
trans_b, trans_x, dense_b, dense_x));
this->get_parameters().unit_diagonal, trans_b, trans_x, dense_b,
dense_x));
},
b, x);
}
Expand Down
11 changes: 5 additions & 6 deletions core/solver/lower_trs_kernels.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -56,19 +56,18 @@ namespace lower_trs {
bool& do_transpose)


#define GKO_DECLARE_LOWER_TRS_GENERATE_KERNEL(_vtype, _itype) \
void generate(std::shared_ptr<const DefaultExecutor> exec, \
const matrix::Csr<_vtype, _itype>* matrix, \
std::shared_ptr<solver::SolveStruct>& solve_struct, \
bool unit_diag, const solver::trisolve_algorithm algorithm, \
#define GKO_DECLARE_LOWER_TRS_GENERATE_KERNEL(_vtype, _itype) \
void generate(std::shared_ptr<const DefaultExecutor> exec, \
const matrix::Csr<_vtype, _itype>* matrix, \
std::shared_ptr<solver::SolveStruct>& solve_struct, \
bool unit_diag, const solver::trisolve_strategy* strategy, \
const size_type num_rhs)


#define GKO_DECLARE_LOWER_TRS_SOLVE_KERNEL(_vtype, _itype) \
void solve(std::shared_ptr<const DefaultExecutor> exec, \
const matrix::Csr<_vtype, _itype>* matrix, \
const solver::SolveStruct* solve_struct, bool unit_diag, \
const solver::trisolve_algorithm algorithm, \
matrix::Dense<_vtype>* trans_b, matrix::Dense<_vtype>* trans_x, \
const matrix::Dense<_vtype>* b, matrix::Dense<_vtype>* x)

Expand Down
8 changes: 4 additions & 4 deletions core/solver/upper_trs.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -135,8 +135,8 @@ void UpperTrs<ValueType, IndexType>::generate()
if (this->get_system_matrix()) {
this->get_executor()->run(upper_trs::make_generate(
this->get_system_matrix().get(), this->solve_struct_,
this->get_parameters().unit_diagonal, parameters_.algorithm,
parameters_.num_rhs));
this->get_parameters().unit_diagonal,
gko::lend(parameters_.strategy), parameters_.num_rhs));
}
}

Expand Down Expand Up @@ -178,8 +178,8 @@ void UpperTrs<ValueType, IndexType>::apply_impl(const LinOp* b, LinOp* x) const
}
exec->run(upper_trs::make_solve(
lend(this->get_system_matrix()), lend(this->solve_struct_),
this->get_parameters().unit_diagonal, parameters_.algorithm,
trans_b, trans_x, dense_b, dense_x));
this->get_parameters().unit_diagonal, trans_b, trans_x, dense_b,
dense_x));
},
b, x);
}
Expand Down
11 changes: 5 additions & 6 deletions core/solver/upper_trs_kernels.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -56,19 +56,18 @@ namespace upper_trs {
bool& do_transpose)


#define GKO_DECLARE_UPPER_TRS_GENERATE_KERNEL(_vtype, _itype) \
void generate(std::shared_ptr<const DefaultExecutor> exec, \
const matrix::Csr<_vtype, _itype>* matrix, \
std::shared_ptr<solver::SolveStruct>& solve_struct, \
bool unit_diag, const solver::trisolve_algorithm algorithm, \
#define GKO_DECLARE_UPPER_TRS_GENERATE_KERNEL(_vtype, _itype) \
void generate(std::shared_ptr<const DefaultExecutor> exec, \
const matrix::Csr<_vtype, _itype>* matrix, \
std::shared_ptr<solver::SolveStruct>& solve_struct, \
bool unit_diag, const solver::trisolve_strategy* strategy, \
const size_type num_rhs)


#define GKO_DECLARE_UPPER_TRS_SOLVE_KERNEL(_vtype, _itype) \
void solve(std::shared_ptr<const DefaultExecutor> exec, \
const matrix::Csr<_vtype, _itype>* matrix, \
const solver::SolveStruct* solve_struct, bool unit_diag, \
const solver::trisolve_algorithm algorithm, \
matrix::Dense<_vtype>* trans_b, matrix::Dense<_vtype>* trans_x, \
const matrix::Dense<_vtype>* b, matrix::Dense<_vtype>* x)

Expand Down
1 change: 1 addition & 0 deletions cuda/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,7 @@ target_sources(ginkgo_cuda
solver/cb_gmres_kernels.cu
solver/idr_kernels.cu
solver/lower_trs_kernels.cu
solver/common_trs_kernels.cu
solver/multigrid_kernels.cu
solver/upper_trs_kernels.cu
stop/criterion_kernels.cu
Expand Down
Loading