From 7d28c4fe5bc096622d6aa39517bfd78157f40fac Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Thomas=20Gr=C3=BCtzmacher?= Date: Thu, 12 Nov 2020 10:51:51 +0100 Subject: [PATCH] Add convert_to_with_sorting helper Additionally, use this helper function everywhere a sorting with conditional sorting is required. --- core/base/utils.hpp | 171 +++++++++++++++++++++++++- core/factorization/par_ict.cpp | 27 +---- core/factorization/par_ilu.cpp | 23 ++-- core/factorization/par_ilut.cpp | 29 ++--- core/preconditioner/isai.cpp | 45 +------ core/preconditioner/jacobi.cpp | 16 +-- reference/test/base/CMakeLists.txt | 1 + reference/test/base/utils.cpp | 187 +++++++++++++++++++++++++++++ 8 files changed, 383 insertions(+), 116 deletions(-) create mode 100644 reference/test/base/utils.cpp diff --git a/core/base/utils.hpp b/core/base/utils.hpp index 4e6fbc1dfce..ace930bd116 100644 --- a/core/base/utils.hpp +++ b/core/base/utils.hpp @@ -33,7 +33,15 @@ OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. #ifndef GKO_INTERNAL_CORE_BASE_UTILS_HPP_ #define GKO_INTERNAL_CORE_BASE_UTILS_HPP_ + +#include +#include + + +#include #include +#include +#include namespace gko { @@ -50,7 +58,168 @@ GKO_ATTRIBUTES GKO_INLINE ValueType checked_load(const ValueType *p, } // namespace kernels + + +namespace detail { + + +template +struct conversion_sort_helper { +}; + +template +struct conversion_sort_helper> { + using mtx_type = matrix::Csr; + template + static std::unique_ptr get_sorted_conversion( + std::shared_ptr &exec, Source *source) + { + auto editable_mtx = mtx_type::create(exec); + as>(source)->convert_to(lend(editable_mtx)); + editable_mtx->sort_by_column_index(); + return editable_mtx; + } +}; + + +template +std::unique_ptr> convert_to_with_sorting_impl( + std::shared_ptr &exec, Source *obj, bool skip_sorting) +{ + if (skip_sorting) { + return copy_and_convert_to(exec, obj); + } else { + using decay_dest = std::decay_t; + auto sorted_mtx = + detail::conversion_sort_helper::get_sorted_conversion( + exec, obj); + return {sorted_mtx.release(), std::default_delete()}; + } +} + +template +std::shared_ptr convert_to_with_sorting_impl( + std::shared_ptr &exec, std::shared_ptr obj, + bool skip_sorting) +{ + if (skip_sorting) { + return copy_and_convert_to(exec, obj); + } else { + using decay_dest = std::decay_t; + auto sorted_mtx = + detail::conversion_sort_helper::get_sorted_conversion( + exec, obj.get()); + return {std::move(sorted_mtx)}; + } +} + + +} // namespace detail + + +/** + * @internal + * + * Helper function that converts the given matrix to the Dest format with + * additional sorting if requested. + * + * If the given matrix was already sorted, is on the same executor and with a + * dynamic type of `Dest`, the same pointer is returned with an empty + * deleter. + * In all other cases, a new matrix is created, which stores the converted + * matrix. + * + * @tparam Dest the type to which the object should be converted + * @tparam Source the type of the source object + * + * @param exec the executor where the result should be placed + * @param obj the source object that should be converted + * @param skip_sorting indicator if the resulting matrix should be sorted or + * not + */ +template +std::unique_ptr> convert_to_with_sorting( + std::shared_ptr exec, Source *obj, bool skip_sorting) +{ + return detail::convert_to_with_sorting_impl(exec, obj, skip_sorting); +} + +/** + * @copydoc convert_to_with_sorting(std::shared_ptr, + * Source *, bool) + * + * @note This version adds the const qualifier for the result since the input is + * also const + */ +template +std::unique_ptr> +convert_to_with_sorting(std::shared_ptr exec, const Source *obj, + bool skip_sorting) +{ + return detail::convert_to_with_sorting_impl(exec, obj, + skip_sorting); +} + +/** + * @copydoc convert_to_with_sorting(std::shared_ptr, + * Source *, bool) + * + * @note This version has a unique_ptr as the source instead of a plain pointer + */ +template +std::unique_ptr> convert_to_with_sorting( + std::shared_ptr exec, const std::unique_ptr &obj, + bool skip_sorting) +{ + return detail::convert_to_with_sorting_impl(exec, obj.get(), + skip_sorting); +} + +/** + * @internal + * + * Helper function that converts the given matrix to the Dest format with + * additional sorting if requested. + * + * If the given matrix was already sorted, is on the same executor and with a + * dynamic type of `Dest`, the same pointer is returned. + * In all other cases, a new matrix is created, which stores the converted + * matrix. + * + * @tparam Dest the type to which the object should be converted + * @tparam Source the type of the source object + * + * @param exec the executor where the result should be placed + * @param obj the source object that should be converted + * @param skip_sorting indicator if the resulting matrix should be sorted or + * not + */ +template +std::shared_ptr convert_to_with_sorting( + std::shared_ptr exec, std::shared_ptr obj, + bool skip_sorting) +{ + return detail::convert_to_with_sorting_impl(exec, obj, skip_sorting); +} + +/** + * @copydoc convert_to_with_sorting(std::shared_ptr, + * std::shared_ptr, bool) + * + * @note This version adds the const qualifier for the result since the input is + * also const + */ +template +std::shared_ptr convert_to_with_sorting( + std::shared_ptr exec, std::shared_ptr obj, + bool skip_sorting) +{ + return detail::convert_to_with_sorting_impl(exec, obj, + skip_sorting); +} + + } // namespace gko -#endif // GKO_INTERNAL_CORE_BASE_UTILS_HPP_ \ No newline at end of file +#endif // GKO_INTERNAL_CORE_BASE_UTILS_HPP_ diff --git a/core/factorization/par_ict.cpp b/core/factorization/par_ict.cpp index a1a1408fb79..40d14a730d5 100644 --- a/core/factorization/par_ict.cpp +++ b/core/factorization/par_ict.cpp @@ -44,6 +44,7 @@ OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. #include +#include "core/base/utils.hpp" #include "core/factorization/factorization_kernels.hpp" #include "core/factorization/par_ict_kernels.hpp" #include "core/factorization/par_ilu_kernels.hpp" @@ -175,30 +176,14 @@ ParIct::generate_l_lt( const auto exec = this->get_executor(); // convert and/or sort the matrix if necessary - std::unique_ptr csr_system_matrix_unique_ptr{}; - auto csr_system_matrix = - dynamic_cast(system_matrix.get()); - if (csr_system_matrix == nullptr || - csr_system_matrix->get_executor() != exec) { - csr_system_matrix_unique_ptr = CsrMatrix::create(exec); - as>(system_matrix.get()) - ->convert_to(csr_system_matrix_unique_ptr.get()); - csr_system_matrix = csr_system_matrix_unique_ptr.get(); - } - if (!parameters_.skip_sorting) { - if (csr_system_matrix_unique_ptr == nullptr) { - csr_system_matrix_unique_ptr = CsrMatrix::create(exec); - csr_system_matrix_unique_ptr->copy_from(csr_system_matrix); - } - csr_system_matrix_unique_ptr->sort_by_column_index(); - csr_system_matrix = csr_system_matrix_unique_ptr.get(); - } + auto csr_system_matrix = convert_to_with_sorting( + exec, system_matrix, parameters_.skip_sorting); // initialize the L matrix data structures const auto num_rows = csr_system_matrix->get_size()[0]; Array l_row_ptrs_array{exec, num_rows + 1}; auto l_row_ptrs = l_row_ptrs_array.get_data(); - exec->run(make_initialize_row_ptrs_l(csr_system_matrix, l_row_ptrs)); + exec->run(make_initialize_row_ptrs_l(csr_system_matrix.get(), l_row_ptrs)); auto l_nnz = static_cast(exec->copy_val_to_host(l_row_ptrs + num_rows)); @@ -209,14 +194,14 @@ ParIct::generate_l_lt( std::move(l_row_ptrs_array)); // initialize L - exec->run(make_initialize_l(csr_system_matrix, l.get(), true)); + exec->run(make_initialize_l(csr_system_matrix.get(), l.get(), true)); // compute limit #nnz for L auto l_nnz_limit = static_cast(l_nnz * parameters_.fill_in_limit); ParIctState state{exec, - csr_system_matrix, + csr_system_matrix.get(), std::move(l), l_nnz_limit, parameters_.approximate_select, diff --git a/core/factorization/par_ilu.cpp b/core/factorization/par_ilu.cpp index d61a27747af..fc984415c41 100644 --- a/core/factorization/par_ilu.cpp +++ b/core/factorization/par_ilu.cpp @@ -83,10 +83,9 @@ ParIlu::generate_l_u( // Converts the system matrix to CSR. // Throws an exception if it is not convertible. - auto csr_system_matrix_unique_ptr = CsrMatrix::create(exec); + auto csr_system_matrix = CsrMatrix::create(exec); as>(system_matrix.get()) - ->convert_to(csr_system_matrix_unique_ptr.get()); - auto csr_system_matrix = csr_system_matrix_unique_ptr.get(); + ->convert_to(csr_system_matrix.get()); // If necessary, sort it if (!skip_sorting) { csr_system_matrix->sort_by_column_index(); @@ -94,14 +93,14 @@ ParIlu::generate_l_u( // Add explicit diagonal zero elements if they are missing exec->run(par_ilu_factorization::make_add_diagonal_elements( - csr_system_matrix, true)); + csr_system_matrix.get(), true)); const auto matrix_size = csr_system_matrix->get_size(); const auto number_rows = matrix_size[0]; Array l_row_ptrs{exec, number_rows + 1}; Array u_row_ptrs{exec, number_rows + 1}; exec->run(par_ilu_factorization::make_initialize_row_ptrs_l_u( - csr_system_matrix, l_row_ptrs.get_data(), u_row_ptrs.get_data())); + csr_system_matrix.get(), l_row_ptrs.get_data(), u_row_ptrs.get_data())); // Get nnz from device memory auto l_nnz = static_cast( @@ -123,7 +122,7 @@ ParIlu::generate_l_u( std::move(u_row_ptrs), u_strategy); exec->run(par_ilu_factorization::make_initialize_l_u( - csr_system_matrix, l_factor.get(), u_factor.get())); + csr_system_matrix.get(), l_factor.get(), u_factor.get())); // We use `transpose()` here to convert the Csr format to Csc. auto u_factor_transpose_lin_op = u_factor->transpose(); @@ -140,18 +139,10 @@ ParIlu::generate_l_u( // If it was not, and we already own a CSR `system_matrix`, // we can move the Csr matrix to Coo, which has very little overhead. - // Otherwise, we convert from the Csr matrix, since it is the conversion - // with the least overhead. - // We also have to convert / move from the CSR matrix if it was not already - // sorted (in which case we definitively own a CSR `system_matrix`). + // We also have to move from the CSR matrix if it was not already sorted. if (!skip_sorting || coo_system_matrix_ptr == nullptr) { coo_system_matrix_unique_ptr = CooMatrix::create(exec); - if (csr_system_matrix_unique_ptr == nullptr) { - csr_system_matrix->convert_to(coo_system_matrix_unique_ptr.get()); - } else { - csr_system_matrix_unique_ptr->move_to( - coo_system_matrix_unique_ptr.get()); - } + csr_system_matrix->move_to(coo_system_matrix_unique_ptr.get()); coo_system_matrix_ptr = coo_system_matrix_unique_ptr.get(); } diff --git a/core/factorization/par_ilut.cpp b/core/factorization/par_ilut.cpp index 1eb3dfeb950..a68a7d11c91 100644 --- a/core/factorization/par_ilut.cpp +++ b/core/factorization/par_ilut.cpp @@ -44,6 +44,7 @@ OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. #include +#include "core/base/utils.hpp" #include "core/factorization/factorization_kernels.hpp" #include "core/factorization/par_ilu_kernels.hpp" #include "core/factorization/par_ilut_kernels.hpp" @@ -191,24 +192,8 @@ ParIlut::generate_l_u( const auto exec = this->get_executor(); // convert and/or sort the matrix if necessary - std::unique_ptr csr_system_matrix_unique_ptr{}; - auto csr_system_matrix = - dynamic_cast(system_matrix.get()); - if (csr_system_matrix == nullptr || - csr_system_matrix->get_executor() != exec) { - csr_system_matrix_unique_ptr = CsrMatrix::create(exec); - as>(system_matrix.get()) - ->convert_to(csr_system_matrix_unique_ptr.get()); - csr_system_matrix = csr_system_matrix_unique_ptr.get(); - } - if (!parameters_.skip_sorting) { - if (csr_system_matrix_unique_ptr == nullptr) { - csr_system_matrix_unique_ptr = CsrMatrix::create(exec); - csr_system_matrix_unique_ptr->copy_from(csr_system_matrix); - } - csr_system_matrix_unique_ptr->sort_by_column_index(); - csr_system_matrix = csr_system_matrix_unique_ptr.get(); - } + auto csr_system_matrix = convert_to_with_sorting( + exec, system_matrix, parameters_.skip_sorting); // initialize the L and U matrix data structures const auto num_rows = csr_system_matrix->get_size()[0]; @@ -216,7 +201,7 @@ ParIlut::generate_l_u( Array u_row_ptrs_array{exec, num_rows + 1}; auto l_row_ptrs = l_row_ptrs_array.get_data(); auto u_row_ptrs = u_row_ptrs_array.get_data(); - exec->run(make_initialize_row_ptrs_l_u(csr_system_matrix, l_row_ptrs, + exec->run(make_initialize_row_ptrs_l_u(csr_system_matrix.get(), l_row_ptrs, u_row_ptrs)); auto l_nnz = @@ -233,7 +218,7 @@ ParIlut::generate_l_u( std::move(u_row_ptrs_array)); // initialize L and U - exec->run(make_initialize_l_u(csr_system_matrix, l.get(), u.get())); + exec->run(make_initialize_l_u(csr_system_matrix.get(), l.get(), u.get())); // compute limit #nnz for L and U auto l_nnz_limit = @@ -242,7 +227,7 @@ ParIlut::generate_l_u( static_cast(u_nnz * parameters_.fill_in_limit); ParIlutState state{exec, - csr_system_matrix, + csr_system_matrix.get(), std::move(l), std::move(u), l_nnz_limit, @@ -352,4 +337,4 @@ GKO_INSTANTIATE_FOR_EACH_VALUE_AND_INDEX_TYPE(GKO_DECLARE_PAR_ILUT); } // namespace factorization -} // namespace gko \ No newline at end of file +} // namespace gko diff --git a/core/preconditioner/isai.cpp b/core/preconditioner/isai.cpp index 0b8738c5594..108ca300acc 100644 --- a/core/preconditioner/isai.cpp +++ b/core/preconditioner/isai.cpp @@ -46,6 +46,7 @@ OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. #include +#include "core/base/utils.hpp" #include "core/preconditioner/isai_kernels.hpp" @@ -62,48 +63,6 @@ GKO_REGISTER_OPERATION(scatter_excess_solution, isai::scatter_excess_solution); } // namespace isai -/** - * @internal - * - * Helper function that converts the given matrix to the (const) CSR format with - * additional sorting. - * - * If the given matrix was already sorted, is on the same executor and with a - * dynamic type of `const Csr`, the same pointer is returned with an empty - * deleter. - * In all other cases, a new matrix is created, which stores the converted Csr - * matrix. - * If `skip_sorting` is false, the matrix will be sorted by column index, - * otherwise, it will not be sorted. - */ -template -std::shared_ptr convert_to_csr_and_sort( - std::shared_ptr &exec, std::shared_ptr mtx, - bool skip_sorting) -{ - static_assert( - std::is_same>::value, - "The given `Csr` type must be of type `matrix::Csr`!"); - if (skip_sorting && exec == mtx->get_executor()) { - auto csr_mtx = std::dynamic_pointer_cast(mtx); - if (csr_mtx) { - // Here, we can just forward the pointer with an empty deleter - // since it is already sorted and in the correct format - return csr_mtx; - } - } - auto copy = Csr::create(exec); - as>(mtx)->convert_to(lend(copy)); - // Here, we assume that a sorted matrix converted to CSR will also be - // sorted - if (!skip_sorting) { - copy->sort_by_column_index(); - } - return {std::move(copy)}; -} - - /** * @internal * @@ -156,7 +115,7 @@ void Isai::generate_inverse( using UpperTrs = solver::UpperTrs; GKO_ASSERT_IS_SQUARE_MATRIX(input); auto exec = this->get_executor(); - auto to_invert = convert_to_csr_and_sort(exec, input, skip_sorting); + auto to_invert = convert_to_with_sorting(exec, input, skip_sorting); auto inverted = extend_sparsity(exec, to_invert, power); auto num_rows = inverted->get_size()[0]; auto is_lower = IsaiType == isai_type::lower; diff --git a/core/preconditioner/jacobi.cpp b/core/preconditioner/jacobi.cpp index 6d52c5b4838..b59a6289a46 100644 --- a/core/preconditioner/jacobi.cpp +++ b/core/preconditioner/jacobi.cpp @@ -45,6 +45,7 @@ OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. #include "core/base/extended_float.hpp" +#include "core/base/utils.hpp" #include "core/preconditioner/jacobi_kernels.hpp" #include "core/preconditioner/jacobi_utils.hpp" @@ -210,19 +211,8 @@ void Jacobi::generate(const LinOp *system_matrix, GKO_ASSERT_IS_SQUARE_MATRIX(system_matrix); using csr_type = matrix::Csr; const auto exec = this->get_executor(); - decltype(copy_and_convert_to(exec, system_matrix)) csr_mtx{}; - - if (skip_sorting) { - csr_mtx = copy_and_convert_to>( - exec, system_matrix); - } else { - auto editable_csr = csr_type::create(exec); - as>(system_matrix) - ->convert_to(lend(editable_csr)); - editable_csr->sort_by_column_index(); - csr_mtx = decltype(csr_mtx){editable_csr.release(), - std::default_delete{}}; - } + auto csr_mtx = + convert_to_with_sorting(exec, system_matrix, skip_sorting); if (parameters_.block_pointers.get_data() == nullptr) { this->detect_blocks(csr_mtx.get()); diff --git a/reference/test/base/CMakeLists.txt b/reference/test/base/CMakeLists.txt index 3b86589507e..3386bb01e20 100644 --- a/reference/test/base/CMakeLists.txt +++ b/reference/test/base/CMakeLists.txt @@ -1,3 +1,4 @@ ginkgo_create_test(combination) ginkgo_create_test(composition) ginkgo_create_test(perturbation) +ginkgo_create_test(utils) diff --git a/reference/test/base/utils.cpp b/reference/test/base/utils.cpp new file mode 100644 index 00000000000..6a4b7d4b9bb --- /dev/null +++ b/reference/test/base/utils.cpp @@ -0,0 +1,187 @@ +/************************************************************* +Copyright (c) 2017-2020, the Ginkgo authors +All rights reserved. + +Redistribution and use in source and binary forms, with or without +modification, are permitted provided that the following conditions +are met: + +1. Redistributions of source code must retain the above copyright +notice, this list of conditions and the following disclaimer. + +2. Redistributions in binary form must reproduce the above copyright +notice, this list of conditions and the following disclaimer in the +documentation and/or other materials provided with the distribution. + +3. Neither the name of the copyright holder nor the names of its +contributors may be used to endorse or promote products derived from +this software without specific prior written permission. + +THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS +IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED +TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A +PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT +HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, +SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT +LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, +DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY +THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT +(INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. +*************************************************************/ + +#include "core/base/utils.hpp" + + +#include + + +#include + + +#include +#include +#include +#include +#include +#include +#include + + +#include "core/test/utils.hpp" +#include "core/test/utils/unsort_matrix.hpp" + + +namespace { + + +class ConvertToWithSorting : public ::testing::Test { +protected: + using value_type = double; + using index_type = gko::int32; + using Dense = gko::matrix::Dense; + using Csr = gko::matrix::Csr; + using Coo = gko::matrix::Coo; + + ConvertToWithSorting() + : ref{gko::ReferenceExecutor::create()}, + mtx{gko::initialize({{1, 2, 3}, {6, 0, 7}, {-1, 8, 0}}, ref)}, + unsorted_coo{Coo::create(ref, gko::dim<2>{3, 3}, + I{1, 3, 2, 7, 6, -1, 8}, + I{0, 2, 1, 2, 0, 0, 1}, + I{0, 0, 0, 1, 1, 2, 2})}, + unsorted_csr{Csr::create( + ref, gko::dim<2>{3, 3}, I{1, 3, 2, 7, 6, -1, 8}, + I{0, 2, 1, 2, 0, 0, 1}, I{0, 3, 5, 7})} + + {} + + std::shared_ptr ref; + std::unique_ptr mtx; + std::unique_ptr unsorted_coo; + std::unique_ptr unsorted_csr; +}; + + +TEST_F(ConvertToWithSorting, SortWithUniquePtr) +{ + auto result = gko::convert_to_with_sorting(ref, unsorted_coo, false); + + ASSERT_TRUE(result->is_sorted_by_column_index()); + GKO_ASSERT_MTX_NEAR(result.get(), mtx.get(), 0.); +} + + +TEST_F(ConvertToWithSorting, DontSortWithUniquePtr) +{ + auto result = gko::convert_to_with_sorting(ref, unsorted_csr, true); + + ASSERT_EQ(result.get(), unsorted_csr.get()); + GKO_ASSERT_MTX_NEAR(result.get(), mtx.get(), 0.); +} + + +TEST_F(ConvertToWithSorting, SortWithSharedPtr) +{ + std::shared_ptr shared = gko::share(unsorted_csr); + + auto result = gko::convert_to_with_sorting(ref, shared, false); + + ASSERT_TRUE(result->is_sorted_by_column_index()); + GKO_ASSERT_MTX_NEAR(result.get(), mtx.get(), 0.); +} + + +TEST_F(ConvertToWithSorting, DontSortWithSharedPtr) +{ + std::shared_ptr shared = gko::share(unsorted_csr); + + auto result = gko::convert_to_with_sorting(ref, shared, true); + + ASSERT_EQ(result.get(), shared.get()); + GKO_ASSERT_MTX_NEAR(result.get(), mtx.get(), 0.); +} + + +TEST_F(ConvertToWithSorting, SortWithSharedConstPtr) +{ + std::shared_ptr shared = gko::share(unsorted_coo); + + auto result = gko::convert_to_with_sorting(ref, shared, false); + + ASSERT_TRUE(result->is_sorted_by_column_index()); + GKO_ASSERT_MTX_NEAR(result.get(), mtx.get(), 0.); +} + + +TEST_F(ConvertToWithSorting, DontSortWithSharedConstPtr) +{ + std::shared_ptr shared = gko::share(unsorted_coo); + + auto result = gko::convert_to_with_sorting(ref, shared, true); + + GKO_ASSERT_MTX_NEAR(result.get(), mtx.get(), 0.); +} + + +TEST_F(ConvertToWithSorting, SortWithRawPtr) +{ + auto result = + gko::convert_to_with_sorting(ref, unsorted_coo.get(), false); + + ASSERT_TRUE(result->is_sorted_by_column_index()); + GKO_ASSERT_MTX_NEAR(result.get(), mtx.get(), 0.); +} + + +TEST_F(ConvertToWithSorting, DontSortWithRawPtr) +{ + auto result = + gko::convert_to_with_sorting(ref, unsorted_coo.get(), true); + + GKO_ASSERT_MTX_NEAR(result.get(), mtx.get(), 0.); +} + + +TEST_F(ConvertToWithSorting, SortWithConstRawPtr) +{ + const Coo *cptr = unsorted_coo.get(); + + auto result = gko::convert_to_with_sorting(ref, cptr, false); + + ASSERT_TRUE(result->is_sorted_by_column_index()); + GKO_ASSERT_MTX_NEAR(result.get(), mtx.get(), 0.); +} + + +TEST_F(ConvertToWithSorting, DontSortWithConstRawPtr) +{ + const auto cptr = mtx.get(); + + auto result = gko::convert_to_with_sorting(ref, cptr, true); + + GKO_ASSERT_MTX_NEAR(result.get(), mtx.get(), 0.); +} + + +} // namespace