diff --git a/core/distributed/row_gatherer.cpp b/core/distributed/row_gatherer.cpp index 3cc80b2e7f7..7466c90439a 100644 --- a/core/distributed/row_gatherer.cpp +++ b/core/distributed/row_gatherer.cpp @@ -161,6 +161,13 @@ RowGatherer::RowGatherer( } +template +const LocalIndexType* RowGatherer::get_const_row_idxs() const +{ + return send_idxs_.get_const_data(); +} + + template RowGatherer::RowGatherer(std::shared_ptr exec, mpi::communicator comm) diff --git a/core/multigrid/pgm.cpp b/core/multigrid/pgm.cpp index 6234f072dd5..7de5c147686 100644 --- a/core/multigrid/pgm.cpp +++ b/core/multigrid/pgm.cpp @@ -264,18 +264,15 @@ array Pgm::communicate_non_local_agg( { auto exec = gko::as(matrix)->get_executor(); const auto comm = matrix->get_communicator(); - auto send_sizes = matrix->send_sizes_; - auto recv_sizes = matrix->recv_sizes_; - auto send_offsets = matrix->send_offsets_; - auto recv_offsets = matrix->recv_offsets_; - auto gather_idxs = matrix->gather_idxs_; - auto total_send_size = send_offsets.back(); - auto total_recv_size = recv_offsets.back(); + auto coll_comm = matrix->row_gatherer_->get_collective_communicator(); + auto total_send_size = coll_comm->get_send_size(); + auto total_recv_size = coll_comm->get_recv_size(); + auto row_gatherer = matrix->row_gatherer_; array send_agg(exec, total_send_size); exec->run(pgm::make_gather_index( send_agg.get_size(), local_agg.get_const_data(), - gather_idxs.get_const_data(), send_agg.get_data())); + row_gatherer->get_const_row_idxs(), send_agg.get_data())); // temporary index map that contains no remote connections to map // local indices to global @@ -296,16 +293,16 @@ array Pgm::communicate_non_local_agg( seng_global_agg.get_data(), host_send_buffer.get_data()); } - auto type = experimental::mpi::type_impl::get_type(); const auto send_ptr = use_host_buffer ? host_send_buffer.get_const_data() : seng_global_agg.get_const_data(); auto recv_ptr = use_host_buffer ? 
host_recv_buffer.get_data() : non_local_agg.get_data(); exec->synchronize(); - comm.all_to_all_v(use_host_buffer ? exec->get_master() : exec, send_ptr, - send_sizes.data(), send_offsets.data(), type, recv_ptr, - recv_sizes.data(), recv_offsets.data(), type); + coll_comm + ->i_all_to_all_v(use_host_buffer ? exec->get_master() : exec, send_ptr, + recv_ptr) + .wait(); /* wait on the returned handle before recv_ptr is read below */ if (use_host_buffer) { exec->copy_from(exec->get_master(), total_recv_size, recv_ptr, non_local_agg.get_data()); diff --git a/include/ginkgo/core/distributed/row_gatherer.hpp b/include/ginkgo/core/distributed/row_gatherer.hpp index f2d53f0f03a..3f5bea3ea9a 100644 --- a/include/ginkgo/core/distributed/row_gatherer.hpp +++ b/include/ginkgo/core/distributed/row_gatherer.hpp @@ -100,6 +100,13 @@ class RowGatherer final std::shared_ptr get_collective_communicator() const; + /** + * Read access to the (local) row indices stored in this gatherer. + * + * @return pointer to the (local) row indices that are gathered + */ + const LocalIndexType* get_const_row_idxs() const; + /** * Creates a distributed::RowGatherer from a given collective communicator * and index map.