From 1fc43ceb35a2d849afc5ee287e550ea80dcca4eb Mon Sep 17 00:00:00 2001 From: Gregor Olenik Date: Fri, 13 Oct 2023 22:09:37 +0200 Subject: [PATCH 1/8] Add pregenerated local solver as factory param --- core/distributed/preconditioner/schwarz.cpp | 13 +++++++++++-- .../core/distributed/preconditioner/schwarz.hpp | 5 +++++ 2 files changed, 16 insertions(+), 2 deletions(-) diff --git a/core/distributed/preconditioner/schwarz.cpp b/core/distributed/preconditioner/schwarz.cpp index 0d1267bc0b4..2b2c33d23e7 100644 --- a/core/distributed/preconditioner/schwarz.cpp +++ b/core/distributed/preconditioner/schwarz.cpp @@ -102,14 +102,23 @@ template void Schwarz::generate( std::shared_ptr system_matrix) { - if (parameters_.local_solver) { + if (parameters_.local_solver && !parameters_.generated_local_solvers) { this->local_solver_ = parameters_.local_solver->generate( as>( system_matrix) ->get_local_matrix()); + } else if (parameters_.generated_local_solvers && + !parameters_.local_solver) { + this->local_solver_ = parameters_.generated_local_solvers; + } else if (!parameters_.generated_local_ && !parameters_.local_solver) { + throw ::gko::InvalidStateError( + __FILE__, __LINE__, __func__, + "Requires either a generated solver or an solver factory"); } else { - GKO_NOT_IMPLEMENTED; + throw ::gko::InvalidStateError( + __FILE__, __LINE__, __func__, + "Provided both a generated solver and a solver factory"); } } diff --git a/include/ginkgo/core/distributed/preconditioner/schwarz.hpp b/include/ginkgo/core/distributed/preconditioner/schwarz.hpp index f31bd96aa2e..5bce97fb414 100644 --- a/include/ginkgo/core/distributed/preconditioner/schwarz.hpp +++ b/include/ginkgo/core/distributed/preconditioner/schwarz.hpp @@ -95,6 +95,11 @@ class Schwarz * Local solver factory. */ GKO_DEFERRED_FACTORY_PARAMETER(local_solver, LinOpFactory); + /** + * Generated Inner solvers. + */ + std::shared_ptr GKO_FACTORY_PARAMETER( + generated_local_solver, nullptr); }; GKO_ENABLE_LIN_OP_FACTORY(Schwarz, parameters, Factory); GKO_ENABLE_BUILD_METHOD(Factory); From f6f01696cd66db5f2203bd162000ca6ffd7e2bf0 Mon Sep 17 00:00:00 2001 From: Gregor Olenik Date: Fri, 13 Oct 2023 23:12:57 +0200 Subject: [PATCH 2/8] Add unit test --- test/mpi/preconditioner/schwarz.cpp | 30 +++++++++++++++++++++++++++++ 1 file changed, 30 insertions(+) diff --git a/test/mpi/preconditioner/schwarz.cpp b/test/mpi/preconditioner/schwarz.cpp index 8d07ba44046..f3269b1d237 100644 --- a/test/mpi/preconditioner/schwarz.cpp +++ b/test/mpi/preconditioner/schwarz.cpp @@ -217,6 +217,36 @@ TYPED_TEST(SchwarzPreconditioner, CanApplyPreconditionedSolver) this->non_dist_x); } +TYPED_TEST(SchwarzPreconditioner, CanApplyPreconditionedSolverWithPregenSolver) +{ + using value_type = typename TestFixture::value_type; + using csr = typename TestFixture::local_matrix_type; + using cg = typename TestFixture::solver_type; + using prec = typename TestFixture::dist_prec_type; + constexpr double tolerance = 1e-20; + auto iter_stop = gko::share( + gko::stop::Iteration::build().with_max_iters(200u).on(this->exec)); + auto tol_stop = gko::share( + gko::stop::ResidualNorm::build() + .with_reduction_factor( + static_cast>(tolerance)) + .on(this->exec)); + this->non_dist_solver_factory = + cg::build() + .with_preconditioner(this->local_solver_factory) + .with_criteria(iter_stop, tol_stop) + .on(this->exec); + auto local_solver = + this->non_dist_solver_factory->generate(this->non_dist_mat); + this->dist_solver_factory = + cg::build() + .with_preconditioner(prec::build() + .with_generated_local_solver(local_solver) + .on(this->exec)) + .with_criteria(iter_stop, tol_stop) + .on(this->exec); +} + TYPED_TEST(SchwarzPreconditioner, CanApplyPreconditioner) { From 174c3c82992727c6fab225c7fa0da42d6b2e96dd Mon Sep 17 00:00:00 2001 From: Gregor Olenik Date: Mon, 16 Oct 2023 11:02:48 +0200 Subject: [PATCH 3/8] Test if generate fails for invalid solver states --- core/distributed/preconditioner/schwarz.cpp | 9 +-- test/mpi/preconditioner/schwarz.cpp | 69 +++++++++++++-------- 2 files changed, 49 insertions(+), 29 deletions(-) diff --git a/core/distributed/preconditioner/schwarz.cpp b/core/distributed/preconditioner/schwarz.cpp index 2b2c33d23e7..90adf384cce 100644 --- a/core/distributed/preconditioner/schwarz.cpp +++ b/core/distributed/preconditioner/schwarz.cpp @@ -102,16 +102,17 @@ template void Schwarz::generate( std::shared_ptr system_matrix) { - if (parameters_.local_solver && !parameters_.generated_local_solvers) { + if (parameters_.local_solver && !parameters_.generated_local_solver) { this->local_solver_ = parameters_.local_solver->generate( as>( system_matrix) ->get_local_matrix()); - } else if (parameters_.generated_local_solvers && + } else if (parameters_.generated_local_solver && + !parameters_.local_solver) { + this->local_solver_ = parameters_.generated_local_solver; + } else if (!parameters_.generated_local_solver && !parameters_.local_solver) { - this->local_solver_ = parameters_.generated_local_solvers; - } else if (!parameters_.generated_local_ && !parameters_.local_solver) { throw ::gko::InvalidStateError( __FILE__, __LINE__, __func__, "Requires either a generated solver or an solver factory"); diff --git a/test/mpi/preconditioner/schwarz.cpp b/test/mpi/preconditioner/schwarz.cpp index f3269b1d237..7a1f69a59a3 100644 --- a/test/mpi/preconditioner/schwarz.cpp +++ b/test/mpi/preconditioner/schwarz.cpp @@ -217,37 +217,56 @@ TYPED_TEST(SchwarzPreconditioner, CanApplyPreconditionedSolver) this->non_dist_x); } -TYPED_TEST(SchwarzPreconditioner, CanApplyPreconditionedSolverWithPregenSolver) + +TYPED_TEST(SchwarzPreconditioner, GenerateFailsIfPregenSolverAndSolverFactoryArePresent) { - using value_type = typename TestFixture::value_type; - using csr = typename TestFixture::local_matrix_type; - using cg = typename TestFixture::solver_type; using prec = typename TestFixture::dist_prec_type; - constexpr double tolerance = 1e-20; - auto iter_stop = gko::share( - gko::stop::Iteration::build().with_max_iters(200u).on(this->exec)); - auto tol_stop = gko::share( - gko::stop::ResidualNorm::build() - .with_reduction_factor( - static_cast>(tolerance)) - .on(this->exec)); - this->non_dist_solver_factory = - cg::build() - .with_preconditioner(this->local_solver_factory) - .with_criteria(iter_stop, tol_stop) - .on(this->exec); auto local_solver = - this->non_dist_solver_factory->generate(this->non_dist_mat); - this->dist_solver_factory = - cg::build() - .with_preconditioner(prec::build() - .with_generated_local_solver(local_solver) - .on(this->exec)) - .with_criteria(iter_stop, tol_stop) - .on(this->exec); + gko::share(this->non_dist_solver_factory->generate(this->non_dist_mat)); + + auto schwarz = prec::build() + .with_local_solver(this->local_solver_factory) + .with_generated_local_solver(local_solver) + .on(this->exec); + + ASSERT_THROW(schwarz->generate(this->dist_mat), gko::InvalidStateError); + + auto schwarz_no_solver = prec::build().on(this->exec); + ASSERT_THROW(schwarz_no_solver->generate(this->dist_mat), gko::InvalidStateError); } +// TYPED_TEST(SchwarzPreconditioner, CanApplyPreconditionedSolverWithPregenSolver) +// { +// using value_type = typename TestFixture::value_type; +// using csr = typename TestFixture::local_matrix_type; +// using cg = typename TestFixture::solver_type; +// using prec = typename TestFixture::dist_prec_type; +// constexpr double tolerance = 1e-20; +// auto iter_stop = gko::share( +// gko::stop::Iteration::build().with_max_iters(200u).on(this->exec)); +// auto tol_stop = gko::share( +// gko::stop::ResidualNorm::build() +// .with_reduction_factor( +// static_cast>(tolerance)) +// .on(this->exec)); +// this->non_dist_solver_factory = +// cg::build() +// .with_preconditioner(this->local_solver_factory) +// .with_criteria(iter_stop, tol_stop) +// .on(this->exec); +// auto local_solver = +// this->non_dist_solver_factory->generate(this->non_dist_mat); +// this->dist_solver_factory = +// cg::build() +// .with_preconditioner(prec::build() +// .with_generated_local_solver(local_solver.get()) +// .on(this->exec)) +// .with_criteria(iter_stop, tol_stop) +// .on(this->exec); +// } + + TYPED_TEST(SchwarzPreconditioner, CanApplyPreconditioner) { using value_type = typename TestFixture::value_type; From 5799477e3e0bf54689c1026b1c5473e1c55ed114 Mon Sep 17 00:00:00 2001 From: Gregor Olenik Date: Tue, 17 Oct 2023 14:58:08 +0200 Subject: [PATCH 4/8] refactor build method a bit, add unit tests --- core/distributed/preconditioner/schwarz.cpp | 28 +++-- test/mpi/preconditioner/schwarz.cpp | 121 +++++++++++--------- 2 files changed, 81 insertions(+), 68 deletions(-) diff --git a/core/distributed/preconditioner/schwarz.cpp b/core/distributed/preconditioner/schwarz.cpp index 90adf384cce..7dfdfd3b4a7 100644 --- a/core/distributed/preconditioner/schwarz.cpp +++ b/core/distributed/preconditioner/schwarz.cpp @@ -102,24 +102,28 @@ template void Schwarz::generate( std::shared_ptr system_matrix) { - if (parameters_.local_solver && !parameters_.generated_local_solver) { + if (parameters_.local_solver != nullptr && + parameters_.generated_local_solver != nullptr) { + throw ::gko::InvalidStateError( + __FILE__, __LINE__, __func__, + "Provided both a generated solver and a solver factory"); + } + + if (parameters_.local_solver == nullptr && + parameters_.generated_local_solver == nullptr) { + throw ::gko::InvalidStateError( + __FILE__, __LINE__, __func__, + "Requires either a generated solver or an solver factory"); + } + + if (parameters_.local_solver) { this->local_solver_ = parameters_.local_solver->generate( as>( system_matrix) ->get_local_matrix()); - } else if (parameters_.generated_local_solver && - !parameters_.local_solver) { - this->local_solver_ = parameters_.generated_local_solver; - } else if (!parameters_.generated_local_solver && - !parameters_.local_solver) { - throw ::gko::InvalidStateError( - __FILE__, __LINE__, __func__, - "Requires either a generated solver or an solver factory"); } else { - throw ::gko::InvalidStateError( - __FILE__, __LINE__, __func__, - "Provided both a generated solver and a solver factory"); + this->local_solver_ = parameters_.generated_local_solver; } } diff --git a/test/mpi/preconditioner/schwarz.cpp b/test/mpi/preconditioner/schwarz.cpp index 7a1f69a59a3..42a043d2e51 100644 --- a/test/mpi/preconditioner/schwarz.cpp +++ b/test/mpi/preconditioner/schwarz.cpp @@ -178,65 +178,32 @@ class SchwarzPreconditioner : public CommonMpiTestFixture { TYPED_TEST_SUITE(SchwarzPreconditioner, gko::test::ValueLocalGlobalIndexTypes, TupleTypenameNameGenerator); - -TYPED_TEST(SchwarzPreconditioner, CanApplyPreconditionedSolver) +TYPED_TEST(SchwarzPreconditioner, GenerateFailsIfInvalidState) { using value_type = typename TestFixture::value_type; - using csr = typename TestFixture::local_matrix_type; - using cg = typename TestFixture::solver_type; - using prec = typename TestFixture::dist_prec_type; - constexpr double tolerance = 1e-20; - auto iter_stop = gko::share( - gko::stop::Iteration::build().with_max_iters(200u).on(this->exec)); - auto tol_stop = gko::share( - gko::stop::ResidualNorm::build() - .with_reduction_factor( - static_cast>(tolerance)) - .on(this->exec)); - this->dist_solver_factory = - cg::build() - .with_preconditioner( - prec::build() - .with_local_solver(this->local_solver_factory) - .on(this->exec)) - .with_criteria(iter_stop, tol_stop) - .on(this->exec); - auto dist_solver = this->dist_solver_factory->generate(this->dist_mat); - this->non_dist_solver_factory = - cg::build() - .with_preconditioner(this->local_solver_factory) - .with_criteria(iter_stop, tol_stop) - .on(this->exec); - auto non_dist_solver = - this->non_dist_solver_factory->generate(this->non_dist_mat); - - dist_solver->apply(this->dist_b.get(), this->dist_x.get()); - non_dist_solver->apply(this->non_dist_b.get(), this->non_dist_x.get()); - - this->assert_equal_to_non_distributed_vector(this->dist_x, - this->non_dist_x); -} - - -TYPED_TEST(SchwarzPreconditioner, GenerateFailsIfPregenSolverAndSolverFactoryArePresent) -{ + using local_index_type = typename TestFixture::local_index_type; + using local_prec_type = + gko::preconditioner::Jacobi; using prec = typename TestFixture::dist_prec_type; - auto local_solver = - gko::share(this->non_dist_solver_factory->generate(this->non_dist_mat)); + auto local_solver = gko::share(local_prec_type::build() + .with_max_block_size(1u) + .on(this->exec) + ->generate(this->non_dist_mat)); auto schwarz = prec::build() - .with_local_solver(this->local_solver_factory) - .with_generated_local_solver(local_solver) - .on(this->exec); + .with_local_solver(this->local_solver_factory) + .with_generated_local_solver(local_solver) + .on(this->exec); ASSERT_THROW(schwarz->generate(this->dist_mat), gko::InvalidStateError); auto schwarz_no_solver = prec::build().on(this->exec); - ASSERT_THROW(schwarz_no_solver->generate(this->dist_mat), gko::InvalidStateError); + ASSERT_THROW(schwarz_no_solver->generate(this->dist_mat), + gko::InvalidStateError); } -// TYPED_TEST(SchwarzPreconditioner, CanApplyPreconditionedSolverWithPregenSolver) +// TYPED_TEST(SchwarzPreconditioner, CanApplyPreconditionedSolver) // { // using value_type = typename TestFixture::value_type; // using csr = typename TestFixture::local_matrix_type; @@ -250,30 +217,72 @@ TYPED_TEST(SchwarzPreconditioner, GenerateFailsIfPregenSolverAndSolverFactoryAre // .with_reduction_factor( // static_cast>(tolerance)) // .on(this->exec)); -// this->non_dist_solver_factory = +// this->dist_solver_factory = // cg::build() -// .with_preconditioner(this->local_solver_factory) +// .with_preconditioner( +// prec::build() +// .with_local_solver(this->local_solver_factory) +// .on(this->exec)) // .with_criteria(iter_stop, tol_stop) // .on(this->exec); -// auto local_solver = -// this->non_dist_solver_factory->generate(this->non_dist_mat); -// this->dist_solver_factory = +// auto dist_solver = this->dist_solver_factory->generate(this->dist_mat); +// this->non_dist_solver_factory = // cg::build() -// .with_preconditioner(prec::build() -// .with_generated_local_solver(local_solver.get()) -// .on(this->exec)) +// .with_preconditioner(this->local_solver_factory) // .with_criteria(iter_stop, tol_stop) // .on(this->exec); +// auto non_dist_solver = +// this->non_dist_solver_factory->generate(this->non_dist_mat); +// +// dist_solver->apply(this->dist_b.get(), this->dist_x.get()); +// dist_solver->apply(this->non_dist_b.get(), this->non_dist_x.get()); +// +// this->assert_equal_to_non_distributed_vector(this->dist_x, +// this->non_dist_x); // } -TYPED_TEST(SchwarzPreconditioner, CanApplyPreconditioner) +TYPED_TEST(SchwarzPreconditioner, CanApplyPreconditionedSolverWithPregenSolver) { using value_type = typename TestFixture::value_type; + using local_index_type = typename TestFixture::local_index_type; + using local_prec_type = + gko::preconditioner::Jacobi; using csr = typename TestFixture::local_matrix_type; using cg = typename TestFixture::solver_type; using prec = typename TestFixture::dist_prec_type; + auto local_solver = gko::share(local_prec_type::build() + .with_max_block_size(1u) + .on(this->exec) + ->generate(this->non_dist_mat)); + auto precond = prec::build() + .with_local_solver(this->local_solver_factory) + .on(this->exec) + ->generate(this->dist_mat); + + auto precond_pregen = prec::build() + .with_generated_local_solver(local_solver) + .on(this->exec) + ->generate(this->dist_mat); + + auto dist_x = gko::share(this->dist_x->clone()); + auto dist_x_pregen = gko::share(this->dist_x->clone()); + + precond->apply(this->dist_b.get(), dist_x.get()); + precond->apply(this->dist_b.get(), dist_x_pregen.get()); + + GKO_ASSERT_MTX_NEAR( + dist_x->get_local_vector(), dist_x_pregen->get_local_vector(), + r::value); +} + + +TYPED_TEST(SchwarzPreconditioner, CanApplyPreconditioner) +{ + using value_type = typename TestFixture::value_type; + using prec = typename TestFixture::dist_prec_type; + auto precond_factory = prec::build() .with_local_solver(this->local_solver_factory) .on(this->exec); From c8a4c18faaac2ccf628118c31b49d234be65d67d Mon Sep 17 00:00:00 2001 From: Gregor Olenik Date: Tue, 17 Oct 2023 16:20:14 +0200 Subject: [PATCH 5/8] add missing test --- test/mpi/preconditioner/schwarz.cpp | 74 ++++++++++++++--------------- 1 file changed, 37 insertions(+), 37 deletions(-) diff --git a/test/mpi/preconditioner/schwarz.cpp b/test/mpi/preconditioner/schwarz.cpp index 42a043d2e51..2241be8f535 100644 --- a/test/mpi/preconditioner/schwarz.cpp +++ b/test/mpi/preconditioner/schwarz.cpp @@ -203,43 +203,43 @@ TYPED_TEST(SchwarzPreconditioner, GenerateFailsIfInvalidState) } -// TYPED_TEST(SchwarzPreconditioner, CanApplyPreconditionedSolver) -// { -// using value_type = typename TestFixture::value_type; -// using csr = typename TestFixture::local_matrix_type; -// using cg = typename TestFixture::solver_type; -// using prec = typename TestFixture::dist_prec_type; -// constexpr double tolerance = 1e-20; -// auto iter_stop = gko::share( -// gko::stop::Iteration::build().with_max_iters(200u).on(this->exec)); -// auto tol_stop = gko::share( -// gko::stop::ResidualNorm::build() -// .with_reduction_factor( -// static_cast>(tolerance)) -// .on(this->exec)); -// this->dist_solver_factory = -// cg::build() -// .with_preconditioner( -// prec::build() -// .with_local_solver(this->local_solver_factory) -// .on(this->exec)) -// .with_criteria(iter_stop, tol_stop) -// .on(this->exec); -// auto dist_solver = this->dist_solver_factory->generate(this->dist_mat); -// this->non_dist_solver_factory = -// cg::build() -// .with_preconditioner(this->local_solver_factory) -// .with_criteria(iter_stop, tol_stop) -// .on(this->exec); -// auto non_dist_solver = -// this->non_dist_solver_factory->generate(this->non_dist_mat); -// -// dist_solver->apply(this->dist_b.get(), this->dist_x.get()); -// dist_solver->apply(this->non_dist_b.get(), this->non_dist_x.get()); -// -// this->assert_equal_to_non_distributed_vector(this->dist_x, -// this->non_dist_x); -// } +TYPED_TEST(SchwarzPreconditioner, CanApplyPreconditionedSolver) +{ + using value_type = typename TestFixture::value_type; + using csr = typename TestFixture::local_matrix_type; + using cg = typename TestFixture::solver_type; + using prec = typename TestFixture::dist_prec_type; + constexpr double tolerance = 1e-20; + auto iter_stop = gko::share( + gko::stop::Iteration::build().with_max_iters(200u).on(this->exec)); + auto tol_stop = gko::share( + gko::stop::ResidualNorm::build() + .with_reduction_factor( + static_cast>(tolerance)) + .on(this->exec)); + this->dist_solver_factory = + cg::build() + .with_preconditioner( + prec::build() + .with_local_solver(this->local_solver_factory) + .on(this->exec)) + .with_criteria(iter_stop, tol_stop) + .on(this->exec); + auto dist_solver = this->dist_solver_factory->generate(this->dist_mat); + this->non_dist_solver_factory = + cg::build() + .with_preconditioner(this->local_solver_factory) + .with_criteria(iter_stop, tol_stop) + .on(this->exec); + auto non_dist_solver = + this->non_dist_solver_factory->generate(this->non_dist_mat); + + dist_solver->apply(this->dist_b.get(), this->dist_x.get()); + non_dist_solver->apply(this->non_dist_b.get(), this->non_dist_x.get()); + + this->assert_equal_to_non_distributed_vector(this->dist_x, + this->non_dist_x); +} TYPED_TEST(SchwarzPreconditioner, CanApplyPreconditionedSolverWithPregenSolver) From 853aa91fceca957df4816d7130c4c4d86f3de53a Mon Sep 17 00:00:00 2001 From: Gregor Olenik Date: Wed, 18 Oct 2023 10:21:46 +0200 Subject: [PATCH 6/8] Implement review comments Co-authored-by: Pratik Nayak --- core/distributed/preconditioner/schwarz.cpp | 12 ++++-------- .../core/distributed/preconditioner/schwarz.hpp | 1 + test/mpi/preconditioner/schwarz.cpp | 13 ++++++++----- 3 files changed, 13 insertions(+), 13 deletions(-) diff --git a/core/distributed/preconditioner/schwarz.cpp b/core/distributed/preconditioner/schwarz.cpp index 7dfdfd3b4a7..dd3f86a1cd9 100644 --- a/core/distributed/preconditioner/schwarz.cpp +++ b/core/distributed/preconditioner/schwarz.cpp @@ -102,17 +102,13 @@ template void Schwarz::generate( std::shared_ptr system_matrix) { - if (parameters_.local_solver != nullptr && - parameters_.generated_local_solver != nullptr) { - throw ::gko::InvalidStateError( - __FILE__, __LINE__, __func__, + if (parameters_.local_solver && parameters_.generated_local_solver) { + GKO_INVALID_STATE( "Provided both a generated solver and a solver factory"); } - if (parameters_.local_solver == nullptr && - parameters_.generated_local_solver == nullptr) { - throw ::gko::InvalidStateError( - __FILE__, __LINE__, __func__, + if (!parameters_.local_solver && !parameters_.generated_local_solver) { + GKO_INVALID_STATE( "Requires either a generated solver or an solver factory"); } diff --git a/include/ginkgo/core/distributed/preconditioner/schwarz.hpp b/include/ginkgo/core/distributed/preconditioner/schwarz.hpp index 5bce97fb414..1b34faff7c4 100644 --- a/include/ginkgo/core/distributed/preconditioner/schwarz.hpp +++ b/include/ginkgo/core/distributed/preconditioner/schwarz.hpp @@ -95,6 +95,7 @@ class Schwarz * Local solver factory. */ GKO_DEFERRED_FACTORY_PARAMETER(local_solver, LinOpFactory); + /** * Generated Inner solvers. */ diff --git a/test/mpi/preconditioner/schwarz.cpp b/test/mpi/preconditioner/schwarz.cpp index 2241be8f535..506a8d1320f 100644 --- a/test/mpi/preconditioner/schwarz.cpp +++ b/test/mpi/preconditioner/schwarz.cpp @@ -196,7 +196,12 @@ TYPED_TEST(SchwarzPreconditioner, GenerateFailsIfInvalidState) .on(this->exec); ASSERT_THROW(schwarz->generate(this->dist_mat), gko::InvalidStateError); +} + +TYPED_TEST(SchwarzPreconditioner, GenerateFailsIfNoSolverProvided) +{ + using prec = typename TestFixture::dist_prec_type; auto schwarz_no_solver = prec::build().on(this->exec); ASSERT_THROW(schwarz_no_solver->generate(this->dist_mat), gko::InvalidStateError); @@ -260,21 +265,19 @@ TYPED_TEST(SchwarzPreconditioner, CanApplyPreconditionedSolverWithPregenSolver) .with_local_solver(this->local_solver_factory) .on(this->exec) ->generate(this->dist_mat); - auto precond_pregen = prec::build() .with_generated_local_solver(local_solver) .on(this->exec) ->generate(this->dist_mat); - auto dist_x = gko::share(this->dist_x->clone()); auto dist_x_pregen = gko::share(this->dist_x->clone()); precond->apply(this->dist_b.get(), dist_x.get()); precond->apply(this->dist_b.get(), dist_x_pregen.get()); - GKO_ASSERT_MTX_NEAR( - dist_x->get_local_vector(), dist_x_pregen->get_local_vector(), - r::value); + GKO_ASSERT_MTX_NEAR(dist_x->get_local_vector(), + dist_x_pregen->get_local_vector(), + r::value); } From b59669dfab8eeccd0da34da850d9f8fee39f82bd Mon Sep 17 00:00:00 2001 From: Gregor Olenik Date: Wed, 18 Oct 2023 14:52:41 +0200 Subject: [PATCH 7/8] Add review suggestions Co-authored-by: Yuhsiang Tsai --- core/distributed/preconditioner/schwarz.cpp | 26 ++++++++++++++----- .../distributed/preconditioner/schwarz.hpp | 10 +++++-- test/mpi/preconditioner/schwarz.cpp | 5 ++-- 3 files changed, 31 insertions(+), 10 deletions(-) diff --git a/core/distributed/preconditioner/schwarz.cpp b/core/distributed/preconditioner/schwarz.cpp index dd3f86a1cd9..45536c9df87 100644 --- a/core/distributed/preconditioner/schwarz.cpp +++ b/core/distributed/preconditioner/schwarz.cpp @@ -98,6 +98,20 @@ void Schwarz::apply_impl( } +template +void Schwarz::set_solver( + std::shared_ptr new_solver) +{ + auto exec = this->get_executor(); + if (new_solver) { + if (new_solver->get_executor() != exec) { + new_solver = gko::clone(exec, new_solver); + } + } + this->local_solver_ = new_solver; +} + + template void Schwarz::generate( std::shared_ptr system_matrix) @@ -113,13 +127,13 @@ void Schwarz::generate( } if (parameters_.local_solver) { - this->local_solver_ = parameters_.local_solver->generate( - as>( - system_matrix) - ->get_local_matrix()); + this->set_solver(gko::share(parameters_.local_solver->generate( + as>(system_matrix) + ->get_local_matrix()))); + } else { - this->local_solver_ = parameters_.generated_local_solver; + this->set_solver(parameters_.generated_local_solver); } } diff --git a/include/ginkgo/core/distributed/preconditioner/schwarz.hpp b/include/ginkgo/core/distributed/preconditioner/schwarz.hpp index 1b34faff7c4..e7cd2b1d471 100644 --- a/include/ginkgo/core/distributed/preconditioner/schwarz.hpp +++ b/include/ginkgo/core/distributed/preconditioner/schwarz.hpp @@ -99,7 +99,7 @@ class Schwarz /** * Generated Inner solvers. */ - std::shared_ptr GKO_FACTORY_PARAMETER( + std::shared_ptr GKO_FACTORY_PARAMETER_SCALAR( generated_local_solver, nullptr); }; GKO_ENABLE_LIN_OP_FACTORY(Schwarz, parameters, Factory); @@ -136,7 +136,6 @@ class Schwarz */ void generate(std::shared_ptr system_matrix); - void apply_impl(const LinOp* b, LinOp* x) const override; template @@ -146,6 +145,13 @@ class Schwarz LinOp* x) const override; private: + /** + * Sets the solver operator used as the local solver. + * + * @param new_solver the new local solver + */ + void set_solver(std::shared_ptr new_solver); + std::shared_ptr local_solver_; }; diff --git a/test/mpi/preconditioner/schwarz.cpp b/test/mpi/preconditioner/schwarz.cpp index 506a8d1320f..f0181cad39a 100644 --- a/test/mpi/preconditioner/schwarz.cpp +++ b/test/mpi/preconditioner/schwarz.cpp @@ -203,6 +203,7 @@ TYPED_TEST(SchwarzPreconditioner, GenerateFailsIfNoSolverProvided) { using prec = typename TestFixture::dist_prec_type; auto schwarz_no_solver = prec::build().on(this->exec); + ASSERT_THROW(schwarz_no_solver->generate(this->dist_mat), gko::InvalidStateError); } @@ -260,7 +261,7 @@ TYPED_TEST(SchwarzPreconditioner, CanApplyPreconditionedSolverWithPregenSolver) auto local_solver = gko::share(local_prec_type::build() .with_max_block_size(1u) .on(this->exec) - ->generate(this->non_dist_mat)); + ->generate(this->dist_mat->get_local_matrix())); auto precond = prec::build() .with_local_solver(this->local_solver_factory) .on(this->exec) @@ -273,7 +274,7 @@ TYPED_TEST(SchwarzPreconditioner, CanApplyPreconditionedSolverWithPregenSolver) auto dist_x_pregen = gko::share(this->dist_x->clone()); precond->apply(this->dist_b.get(), dist_x.get()); - precond->apply(this->dist_b.get(), dist_x_pregen.get()); + precond_pregen->apply(this->dist_b.get(), dist_x_pregen.get()); GKO_ASSERT_MTX_NEAR(dist_x->get_local_vector(), dist_x_pregen->get_local_vector(), From 30c9a7ca4c5880044f9f914381231d1aebd530ab Mon Sep 17 00:00:00 2001 From: ginkgo-bot Date: Mon, 23 Oct 2023 08:05:19 +0000 Subject: [PATCH 8/8] Format files Co-authored-by: Gregor Olenik --- test/mpi/preconditioner/schwarz.cpp | 9 +++++---- 1 file changed, 5 insertions(+), 4 deletions(-) diff --git a/test/mpi/preconditioner/schwarz.cpp b/test/mpi/preconditioner/schwarz.cpp index f0181cad39a..3c6dbf33a52 100644 --- a/test/mpi/preconditioner/schwarz.cpp +++ b/test/mpi/preconditioner/schwarz.cpp @@ -258,10 +258,11 @@ TYPED_TEST(SchwarzPreconditioner, CanApplyPreconditionedSolverWithPregenSolver) using cg = typename TestFixture::solver_type; using prec = typename TestFixture::dist_prec_type; - auto local_solver = gko::share(local_prec_type::build() - .with_max_block_size(1u) - .on(this->exec) - ->generate(this->dist_mat->get_local_matrix())); + auto local_solver = + gko::share(local_prec_type::build() + .with_max_block_size(1u) + .on(this->exec) + ->generate(this->dist_mat->get_local_matrix())); auto precond = prec::build() .with_local_solver(this->local_solver_factory) .on(this->exec)