Skip to content

Commit

Permalink
Merge factorization unpack functionality
Browse files Browse the repository at this point in the history
This allows unpacking combined representations of LU and Cholesky factorizations into their factors.

Related PR: #1432
  • Loading branch information
upsj committed Oct 16, 2023
2 parents 2f3720f + 22ae5bb commit 9f71bcd
Show file tree
Hide file tree
Showing 2 changed files with 159 additions and 2 deletions.
74 changes: 72 additions & 2 deletions core/factorization/factorization.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -33,18 +33,88 @@ OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
#include <ginkgo/core/factorization/factorization.hpp>


#include <ginkgo/core/base/exception_helpers.hpp>
#include <ginkgo/core/base/types.hpp>
#include <ginkgo/core/matrix/csr.hpp>


#include "core/factorization/factorization_kernels.hpp"


namespace gko {
namespace experimental {
namespace factorization {
namespace {


GKO_REGISTER_OPERATION(initialize_row_ptrs_l_u,
factorization::initialize_row_ptrs_l_u);
GKO_REGISTER_OPERATION(initialize_l_u, factorization::initialize_l_u);
GKO_REGISTER_OPERATION(initialize_row_ptrs_l,
factorization::initialize_row_ptrs_l);
GKO_REGISTER_OPERATION(initialize_l, factorization::initialize_l);


} // namespace


template <typename ValueType, typename IndexType>
std::unique_ptr<Factorization<ValueType, IndexType>>
Factorization<ValueType, IndexType>::unpack() const GKO_NOT_IMPLEMENTED;
Factorization<ValueType, IndexType>::unpack() const
{
const auto exec = this->get_executor();
const auto size = this->get_size();
switch (this->get_storage_type()) {
case storage_type::empty:
GKO_NOT_SUPPORTED(nullptr);
case storage_type::composition:
case storage_type::symm_composition:
return this->clone();
case storage_type::combined_lu: {
// count nonzeros
array<index_type> l_row_ptrs{exec, size[0] + 1};
array<index_type> u_row_ptrs{exec, size[0] + 1};
const auto mtx = this->get_combined();
exec->run(make_initialize_row_ptrs_l_u(mtx.get(), l_row_ptrs.get_data(),
u_row_ptrs.get_data()));
const auto l_nnz = static_cast<size_type>(
exec->copy_val_to_host(l_row_ptrs.get_const_data() + size[0]));
const auto u_nnz = static_cast<size_type>(
exec->copy_val_to_host(u_row_ptrs.get_const_data() + size[0]));
// create matrices
auto l_mtx = matrix_type::create(
exec, size, array<value_type>{exec, l_nnz},
array<index_type>{exec, l_nnz}, std::move(l_row_ptrs));
auto u_mtx = matrix_type::create(
exec, size, array<value_type>{exec, u_nnz},
array<index_type>{exec, u_nnz}, std::move(u_row_ptrs));
// fill matrices
exec->run(make_initialize_l_u(mtx.get(), l_mtx.get(), u_mtx.get()));
return create_from_composition(
composition_type::create(std::move(l_mtx), std::move(u_mtx)));
}
case storage_type::symm_combined_cholesky: {
// count nonzeros
array<index_type> l_row_ptrs{exec, size[0] + 1};
const auto mtx = this->get_combined();
exec->run(make_initialize_row_ptrs_l(mtx.get(), l_row_ptrs.get_data()));
const auto l_nnz = static_cast<size_type>(
exec->copy_val_to_host(l_row_ptrs.get_const_data() + size[0]));
// create matrices
auto l_mtx = matrix_type::create(
exec, size, array<value_type>{exec, l_nnz},
array<index_type>{exec, l_nnz}, std::move(l_row_ptrs));
// fill matrices
exec->run(make_initialize_l(mtx.get(), l_mtx.get(), false));
auto u_mtx = l_mtx->conj_transpose();
return create_from_symm_composition(
composition_type::create(std::move(l_mtx), std::move(u_mtx)));
}
case storage_type::combined_ldu:
case storage_type::symm_combined_ldl:
GKO_NOT_IMPLEMENTED;
}
}


template <typename ValueType, typename IndexType>
Expand All @@ -58,7 +128,7 @@ template <typename ValueType, typename IndexType>
std::shared_ptr<const gko::matrix::Csr<ValueType, IndexType>>
Factorization<ValueType, IndexType>::get_lower_factor() const
{
switch (storage_type_) {
switch (this->get_storage_type()) {
case storage_type::composition:
case storage_type::symm_composition:
GKO_ASSERT(factors_->get_operators().size() == 2 ||
Expand Down
87 changes: 87 additions & 0 deletions reference/test/factorization/factorization.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -71,9 +71,13 @@ class Factorization : public ::testing::Test {
: ref(gko::ReferenceExecutor::create()),
lower_mtx{gko::initialize<matrix_type>(
{{1.0, 0.0, 0.0}, {3.0, 1.0, 0.0}, {1.0, 2.0, 1.0}}, ref)},
lower_cholesky_mtx{gko::initialize<matrix_type>(
{{1.0, 0.0, 0.0}, {3.0, -1.0, 0.0}, {1.0, 2.0, 5.0}}, ref)},
diagonal{diag_type::create(ref, 3)},
upper_mtx(gko::initialize<matrix_type>(
{{1.0, 2.0, 1.0}, {0.0, 1.0, 3.0}, {0.0, 0.0, 1.0}}, ref)),
upper_nonunit_mtx(gko::initialize<matrix_type>(
{{1.0, 2.0, 1.0}, {0.0, -1.0, 3.0}, {0.0, 0.0, 5.0}}, ref)),
combined_mtx(gko::initialize<matrix_type>(
{{1.0, 2.0, 1.0}, {3.0, -1.0, 3.0}, {1.0, 2.0, 5.0}}, ref)),
input(gko::initialize<vector_type>({1.0, 2.0, 3.0}, ref)),
Expand All @@ -88,8 +92,10 @@ class Factorization : public ::testing::Test {

std::shared_ptr<const gko::ReferenceExecutor> ref;
std::shared_ptr<matrix_type> lower_mtx;
std::shared_ptr<matrix_type> lower_cholesky_mtx;
std::shared_ptr<diag_type> diagonal;
std::shared_ptr<matrix_type> upper_mtx;
std::shared_ptr<matrix_type> upper_nonunit_mtx;
std::shared_ptr<matrix_type> combined_mtx;
std::shared_ptr<vector_type> input;
std::shared_ptr<vector_type> output;
Expand Down Expand Up @@ -261,6 +267,87 @@ TYPED_TEST(Factorization, CreateSymmCombinedLDLWorks)
}


TYPED_TEST(Factorization, UnpackCombinedLUWorks)
{
using factorization_type = typename TestFixture::factorization_type;
auto fact = factorization_type::create_from_combined_lu(
this->combined_mtx->clone());

auto separated = fact->unpack();

ASSERT_EQ(separated->get_storage_type(),
gko::experimental::factorization::storage_type::composition);
ASSERT_EQ(separated->get_combined(), nullptr);
ASSERT_EQ(separated->get_diagonal(), nullptr);
GKO_ASSERT_MTX_NEAR(separated->get_lower_factor(), this->lower_mtx, 0.0);
GKO_ASSERT_MTX_NEAR(separated->get_upper_factor(), this->upper_nonunit_mtx,
0.0);
}


TYPED_TEST(Factorization, UnpackSymmCombinedCholeskyWorks)
{
using matrix_type = typename TestFixture::matrix_type;
using factorization_type = typename TestFixture::factorization_type;
auto fact = factorization_type::create_from_combined_cholesky(
this->combined_mtx->clone());

auto separated = fact->unpack();

ASSERT_EQ(separated->get_storage_type(),
gko::experimental::factorization::storage_type::symm_composition);
ASSERT_EQ(separated->get_combined(), nullptr);
ASSERT_EQ(separated->get_diagonal(), nullptr);
GKO_ASSERT_MTX_NEAR(separated->get_lower_factor(), this->lower_cholesky_mtx,
0.0);
GKO_ASSERT_MTX_NEAR(
separated->get_upper_factor(),
gko::as<matrix_type>(this->lower_cholesky_mtx->conj_transpose()), 0.0);
}


TYPED_TEST(Factorization, UnpackCompositionWorks)
{
using factorization_type = typename TestFixture::factorization_type;
using composition_type = typename TestFixture::composition_type;
auto fact = factorization_type::create_from_composition(
composition_type::create(this->lower_mtx, this->upper_nonunit_mtx));

auto separated = fact->unpack();

ASSERT_EQ(separated->get_storage_type(),
gko::experimental::factorization::storage_type::composition);
ASSERT_EQ(separated->get_combined(), nullptr);
ASSERT_EQ(separated->get_diagonal(), nullptr);
GKO_ASSERT_MTX_NEAR(separated->get_lower_factor(), this->lower_mtx, 0.0);
GKO_ASSERT_MTX_NEAR(separated->get_upper_factor(), this->upper_nonunit_mtx,
0.0);
}


TYPED_TEST(Factorization, UnpackSymmCompositionWorks)
{
using matrix_type = typename TestFixture::matrix_type;
using factorization_type = typename TestFixture::factorization_type;
using composition_type = typename TestFixture::composition_type;
auto fact = factorization_type::create_from_symm_composition(
composition_type::create(this->lower_cholesky_mtx,
this->lower_cholesky_mtx->conj_transpose()));

auto separated = fact->unpack();

ASSERT_EQ(separated->get_storage_type(),
gko::experimental::factorization::storage_type::symm_composition);
ASSERT_EQ(separated->get_combined(), nullptr);
ASSERT_EQ(separated->get_diagonal(), nullptr);
GKO_ASSERT_MTX_NEAR(separated->get_lower_factor(), this->lower_cholesky_mtx,
0.0);
GKO_ASSERT_MTX_NEAR(
separated->get_upper_factor(),
gko::as<matrix_type>(this->lower_cholesky_mtx->conj_transpose()), 0.0);
}


TYPED_TEST(Factorization, ApplyFromCompositionWorks)
{
using factorization_type = typename TestFixture::factorization_type;
Expand Down

0 comments on commit 9f71bcd

Please sign in to comment.