Skip to content

Commit

Permalink
Add convert_to_with_sorting helper
Browse files Browse the repository at this point in the history
Additionally, use this helper function everywhere a conversion with
conditional sorting is required.
  • Loading branch information
Thomas Grützmacher committed Nov 12, 2020
1 parent a61d8e0 commit 7d28c4f
Show file tree
Hide file tree
Showing 8 changed files with 383 additions and 116 deletions.
171 changes: 170 additions & 1 deletion core/base/utils.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,15 @@ OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
#ifndef GKO_INTERNAL_CORE_BASE_UTILS_HPP_
#define GKO_INTERNAL_CORE_BASE_UTILS_HPP_


#include <memory>
#include <type_traits>


#include <ginkgo/core/base/polymorphic_object.hpp>
#include <ginkgo/core/base/types.hpp>
#include <ginkgo/core/base/utils.hpp>
#include <ginkgo/core/matrix/csr.hpp>


namespace gko {
Expand All @@ -50,7 +58,168 @@ GKO_ATTRIBUTES GKO_INLINE ValueType checked_load(const ValueType *p,


} // namespace kernels


namespace detail {


/**
 * @internal
 *
 * Customization point used when a conversion must also sort its result.
 *
 * The primary template is intentionally empty: only the specializations for
 * supported destination types provide `get_sorted_conversion`.
 *
 * @tparam Dest  the destination matrix type
 */
template <typename Dest>
struct conversion_sort_helper {};

template <typename ValueType, typename IndexType>
struct conversion_sort_helper<matrix::Csr<ValueType, IndexType>> {
using mtx_type = matrix::Csr<ValueType, IndexType>;
template <typename Source>
static std::unique_ptr<mtx_type> get_sorted_conversion(
std::shared_ptr<const Executor> &exec, Source *source)
{
auto editable_mtx = mtx_type::create(exec);
as<ConvertibleTo<mtx_type>>(source)->convert_to(lend(editable_mtx));
editable_mtx->sort_by_column_index();
return editable_mtx;
}
};


/**
 * @internal
 *
 * Implementation of convert_to_with_sorting for plain-pointer sources.
 *
 * When `skip_sorting` is true, the call is forwarded to copy_and_convert_to.
 * Otherwise, a sorted copy is built through conversion_sort_helper and
 * returned with a std::default_delete deleter.
 *
 * @tparam Dest  the (possibly const-qualified) destination type
 * @tparam Source  the static type of the source object
 *
 * @param exec  the executor where the result should be placed
 * @param obj  the source object that should be converted
 * @param skip_sorting  if true, no sorting is performed
 */
template <typename Dest, typename Source>
std::unique_ptr<Dest, std::function<void(Dest *)>> convert_to_with_sorting_impl(
    std::shared_ptr<const Executor> &exec, Source *obj, bool skip_sorting)
{
    if (!skip_sorting) {
        // Strip cv-qualifiers so that the helper specialization is found for
        // const destinations as well.
        using sort_helper = detail::conversion_sort_helper<std::decay_t<Dest>>;
        auto sorted = sort_helper::get_sorted_conversion(exec, obj);
        return {sorted.release(), std::default_delete<Dest>()};
    }
    return copy_and_convert_to<Dest>(exec, obj);
}

/**
 * @internal
 *
 * Implementation of convert_to_with_sorting for shared_ptr sources.
 *
 * When `skip_sorting` is true, the call is forwarded to copy_and_convert_to.
 * Otherwise, a sorted copy is built through conversion_sort_helper and its
 * ownership is handed over to the returned shared_ptr.
 *
 * @tparam Dest  the (possibly const-qualified) destination type
 * @tparam Source  the static type of the source object
 *
 * @param exec  the executor where the result should be placed
 * @param obj  the source object that should be converted
 * @param skip_sorting  if true, no sorting is performed
 */
template <typename Dest, typename Source>
std::shared_ptr<Dest> convert_to_with_sorting_impl(
    std::shared_ptr<const Executor> &exec, std::shared_ptr<Source> obj,
    bool skip_sorting)
{
    if (!skip_sorting) {
        // Strip cv-qualifiers so that the helper specialization is found for
        // const destinations as well.
        using sort_helper = detail::conversion_sort_helper<std::decay_t<Dest>>;
        auto sorted = sort_helper::get_sorted_conversion(exec, obj.get());
        return {std::move(sorted)};
    }
    return copy_and_convert_to<Dest>(exec, obj);
}


} // namespace detail


/**
 * @internal
 *
 * Helper function that converts the given matrix to the Dest format and,
 * unless `skip_sorting` is set, additionally sorts the result.
 *
 * If the given matrix was already sorted, is on the same executor and with a
 * dynamic type of `Dest`, the same pointer is returned with an empty
 * deleter.
 * In all other cases, a new matrix is created, which stores the converted
 * matrix.
 *
 * @tparam Dest  the type to which the object should be converted
 * @tparam Source  the type of the source object
 *
 * @param exec  the executor where the result should be placed
 * @param obj  the source object that should be converted
 * @param skip_sorting  if true, the sorting step is skipped (the matrix is
 *                      treated as already sorted); if false, the resulting
 *                      matrix is sorted
 */
template <typename Dest, typename Source>
std::unique_ptr<Dest, std::function<void(Dest *)>> convert_to_with_sorting(
    std::shared_ptr<const Executor> exec, Source *obj, bool skip_sorting)
{
    auto converted =
        detail::convert_to_with_sorting_impl<Dest>(exec, obj, skip_sorting);
    return converted;
}

/**
 * @copydoc convert_to_with_sorting(std::shared_ptr<const Executor>,
 *                                  Source *, bool)
 *
 * @note Since the source is const, this overload returns a pointer to a
 *       const `Dest` as well.
 */
template <typename Dest, typename Source>
std::unique_ptr<const Dest, std::function<void(const Dest *)>>
convert_to_with_sorting(std::shared_ptr<const Executor> exec, const Source *obj,
                        bool skip_sorting)
{
    auto converted = detail::convert_to_with_sorting_impl<const Dest>(
        exec, obj, skip_sorting);
    return converted;
}

/**
 * @copydoc convert_to_with_sorting(std::shared_ptr<const Executor>,
 *                                  Source *, bool)
 *
 * @note This overload accepts the source as a unique_ptr instead of a plain
 *       pointer; the unique_ptr keeps ownership of the source.
 */
template <typename Dest, typename Source>
std::unique_ptr<Dest, std::function<void(Dest *)>> convert_to_with_sorting(
    std::shared_ptr<const Executor> exec, const std::unique_ptr<Source> &obj,
    bool skip_sorting)
{
    Source *const raw_source = obj.get();
    auto converted = detail::convert_to_with_sorting_impl<Dest>(
        exec, raw_source, skip_sorting);
    return converted;
}

/**
 * @internal
 *
 * Helper function that converts the given matrix to the Dest format and,
 * unless `skip_sorting` is set, additionally sorts the result.
 *
 * If the given matrix was already sorted, is on the same executor and with a
 * dynamic type of `Dest`, the same pointer is returned.
 * In all other cases, a new matrix is created, which stores the converted
 * matrix.
 *
 * @tparam Dest  the type to which the object should be converted
 * @tparam Source  the type of the source object
 *
 * @param exec  the executor where the result should be placed
 * @param obj  the source object that should be converted
 * @param skip_sorting  if true, the sorting step is skipped (the matrix is
 *                      treated as already sorted); if false, the resulting
 *                      matrix is sorted
 */
template <typename Dest, typename Source>
std::shared_ptr<Dest> convert_to_with_sorting(
    std::shared_ptr<const Executor> exec, std::shared_ptr<Source> obj,
    bool skip_sorting)
{
    auto converted =
        detail::convert_to_with_sorting_impl<Dest>(exec, obj, skip_sorting);
    return converted;
}

/**
 * @copydoc convert_to_with_sorting(std::shared_ptr<const Executor>,
 *                                  std::shared_ptr<Source>, bool)
 *
 * @note Since the source is const, this overload returns a pointer to a
 *       const `Dest` as well.
 */
template <typename Dest, typename Source>
std::shared_ptr<const Dest> convert_to_with_sorting(
    std::shared_ptr<const Executor> exec, std::shared_ptr<const Source> obj,
    bool skip_sorting)
{
    auto converted = detail::convert_to_with_sorting_impl<const Dest>(
        exec, obj, skip_sorting);
    return converted;
}


} // namespace gko


#endif // GKO_INTERNAL_CORE_BASE_UTILS_HPP_
#endif // GKO_INTERNAL_CORE_BASE_UTILS_HPP_
27 changes: 6 additions & 21 deletions core/factorization/par_ict.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,7 @@ OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
#include <ginkgo/core/matrix/csr.hpp>


#include "core/base/utils.hpp"
#include "core/factorization/factorization_kernels.hpp"
#include "core/factorization/par_ict_kernels.hpp"
#include "core/factorization/par_ilu_kernels.hpp"
Expand Down Expand Up @@ -175,30 +176,14 @@ ParIct<ValueType, IndexType>::generate_l_lt(
const auto exec = this->get_executor();

// convert and/or sort the matrix if necessary
std::unique_ptr<CsrMatrix> csr_system_matrix_unique_ptr{};
auto csr_system_matrix =
dynamic_cast<const CsrMatrix *>(system_matrix.get());
if (csr_system_matrix == nullptr ||
csr_system_matrix->get_executor() != exec) {
csr_system_matrix_unique_ptr = CsrMatrix::create(exec);
as<ConvertibleTo<CsrMatrix>>(system_matrix.get())
->convert_to(csr_system_matrix_unique_ptr.get());
csr_system_matrix = csr_system_matrix_unique_ptr.get();
}
if (!parameters_.skip_sorting) {
if (csr_system_matrix_unique_ptr == nullptr) {
csr_system_matrix_unique_ptr = CsrMatrix::create(exec);
csr_system_matrix_unique_ptr->copy_from(csr_system_matrix);
}
csr_system_matrix_unique_ptr->sort_by_column_index();
csr_system_matrix = csr_system_matrix_unique_ptr.get();
}
auto csr_system_matrix = convert_to_with_sorting<CsrMatrix>(
exec, system_matrix, parameters_.skip_sorting);

// initialize the L matrix data structures
const auto num_rows = csr_system_matrix->get_size()[0];
Array<IndexType> l_row_ptrs_array{exec, num_rows + 1};
auto l_row_ptrs = l_row_ptrs_array.get_data();
exec->run(make_initialize_row_ptrs_l(csr_system_matrix, l_row_ptrs));
exec->run(make_initialize_row_ptrs_l(csr_system_matrix.get(), l_row_ptrs));

auto l_nnz =
static_cast<size_type>(exec->copy_val_to_host(l_row_ptrs + num_rows));
Expand All @@ -209,14 +194,14 @@ ParIct<ValueType, IndexType>::generate_l_lt(
std::move(l_row_ptrs_array));

// initialize L
exec->run(make_initialize_l(csr_system_matrix, l.get(), true));
exec->run(make_initialize_l(csr_system_matrix.get(), l.get(), true));

// compute limit #nnz for L
auto l_nnz_limit =
static_cast<IndexType>(l_nnz * parameters_.fill_in_limit);

ParIctState<ValueType, IndexType> state{exec,
csr_system_matrix,
csr_system_matrix.get(),
std::move(l),
l_nnz_limit,
parameters_.approximate_select,
Expand Down
23 changes: 7 additions & 16 deletions core/factorization/par_ilu.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -83,25 +83,24 @@ ParIlu<ValueType, IndexType>::generate_l_u(

// Converts the system matrix to CSR.
// Throws an exception if it is not convertible.
auto csr_system_matrix_unique_ptr = CsrMatrix::create(exec);
auto csr_system_matrix = CsrMatrix::create(exec);
as<ConvertibleTo<CsrMatrix>>(system_matrix.get())
->convert_to(csr_system_matrix_unique_ptr.get());
auto csr_system_matrix = csr_system_matrix_unique_ptr.get();
->convert_to(csr_system_matrix.get());
// If necessary, sort it
if (!skip_sorting) {
csr_system_matrix->sort_by_column_index();
}

// Add explicit diagonal zero elements if they are missing
exec->run(par_ilu_factorization::make_add_diagonal_elements(
csr_system_matrix, true));
csr_system_matrix.get(), true));

const auto matrix_size = csr_system_matrix->get_size();
const auto number_rows = matrix_size[0];
Array<IndexType> l_row_ptrs{exec, number_rows + 1};
Array<IndexType> u_row_ptrs{exec, number_rows + 1};
exec->run(par_ilu_factorization::make_initialize_row_ptrs_l_u(
csr_system_matrix, l_row_ptrs.get_data(), u_row_ptrs.get_data()));
csr_system_matrix.get(), l_row_ptrs.get_data(), u_row_ptrs.get_data()));

// Get nnz from device memory
auto l_nnz = static_cast<size_type>(
Expand All @@ -123,7 +122,7 @@ ParIlu<ValueType, IndexType>::generate_l_u(
std::move(u_row_ptrs), u_strategy);

exec->run(par_ilu_factorization::make_initialize_l_u(
csr_system_matrix, l_factor.get(), u_factor.get()));
csr_system_matrix.get(), l_factor.get(), u_factor.get()));

// We use `transpose()` here to convert the Csr format to Csc.
auto u_factor_transpose_lin_op = u_factor->transpose();
Expand All @@ -140,18 +139,10 @@ ParIlu<ValueType, IndexType>::generate_l_u(

// If it was not, and we already own a CSR `system_matrix`,
// we can move the Csr matrix to Coo, which has very little overhead.
// Otherwise, we convert from the Csr matrix, since it is the conversion
// with the least overhead.
// We also have to convert / move from the CSR matrix if it was not already
// sorted (in which case we definitively own a CSR `system_matrix`).
// We also have to move from the CSR matrix if it was not already sorted.
if (!skip_sorting || coo_system_matrix_ptr == nullptr) {
coo_system_matrix_unique_ptr = CooMatrix::create(exec);
if (csr_system_matrix_unique_ptr == nullptr) {
csr_system_matrix->convert_to(coo_system_matrix_unique_ptr.get());
} else {
csr_system_matrix_unique_ptr->move_to(
coo_system_matrix_unique_ptr.get());
}
csr_system_matrix->move_to(coo_system_matrix_unique_ptr.get());
coo_system_matrix_ptr = coo_system_matrix_unique_ptr.get();
}

Expand Down
29 changes: 7 additions & 22 deletions core/factorization/par_ilut.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,7 @@ OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
#include <ginkgo/core/matrix/csr.hpp>


#include "core/base/utils.hpp"
#include "core/factorization/factorization_kernels.hpp"
#include "core/factorization/par_ilu_kernels.hpp"
#include "core/factorization/par_ilut_kernels.hpp"
Expand Down Expand Up @@ -191,32 +192,16 @@ ParIlut<ValueType, IndexType>::generate_l_u(
const auto exec = this->get_executor();

// convert and/or sort the matrix if necessary
std::unique_ptr<CsrMatrix> csr_system_matrix_unique_ptr{};
auto csr_system_matrix =
dynamic_cast<const CsrMatrix *>(system_matrix.get());
if (csr_system_matrix == nullptr ||
csr_system_matrix->get_executor() != exec) {
csr_system_matrix_unique_ptr = CsrMatrix::create(exec);
as<ConvertibleTo<CsrMatrix>>(system_matrix.get())
->convert_to(csr_system_matrix_unique_ptr.get());
csr_system_matrix = csr_system_matrix_unique_ptr.get();
}
if (!parameters_.skip_sorting) {
if (csr_system_matrix_unique_ptr == nullptr) {
csr_system_matrix_unique_ptr = CsrMatrix::create(exec);
csr_system_matrix_unique_ptr->copy_from(csr_system_matrix);
}
csr_system_matrix_unique_ptr->sort_by_column_index();
csr_system_matrix = csr_system_matrix_unique_ptr.get();
}
auto csr_system_matrix = convert_to_with_sorting<CsrMatrix>(
exec, system_matrix, parameters_.skip_sorting);

// initialize the L and U matrix data structures
const auto num_rows = csr_system_matrix->get_size()[0];
Array<IndexType> l_row_ptrs_array{exec, num_rows + 1};
Array<IndexType> u_row_ptrs_array{exec, num_rows + 1};
auto l_row_ptrs = l_row_ptrs_array.get_data();
auto u_row_ptrs = u_row_ptrs_array.get_data();
exec->run(make_initialize_row_ptrs_l_u(csr_system_matrix, l_row_ptrs,
exec->run(make_initialize_row_ptrs_l_u(csr_system_matrix.get(), l_row_ptrs,
u_row_ptrs));

auto l_nnz =
Expand All @@ -233,7 +218,7 @@ ParIlut<ValueType, IndexType>::generate_l_u(
std::move(u_row_ptrs_array));

// initialize L and U
exec->run(make_initialize_l_u(csr_system_matrix, l.get(), u.get()));
exec->run(make_initialize_l_u(csr_system_matrix.get(), l.get(), u.get()));

// compute limit #nnz for L and U
auto l_nnz_limit =
Expand All @@ -242,7 +227,7 @@ ParIlut<ValueType, IndexType>::generate_l_u(
static_cast<IndexType>(u_nnz * parameters_.fill_in_limit);

ParIlutState<ValueType, IndexType> state{exec,
csr_system_matrix,
csr_system_matrix.get(),
std::move(l),
std::move(u),
l_nnz_limit,
Expand Down Expand Up @@ -352,4 +337,4 @@ GKO_INSTANTIATE_FOR_EACH_VALUE_AND_INDEX_TYPE(GKO_DECLARE_PAR_ILUT);


} // namespace factorization
} // namespace gko
} // namespace gko
Loading

0 comments on commit 7d28c4f

Please sign in to comment.