diff --git a/core/distributed/matrix.cpp b/core/distributed/matrix.cpp
index 815bfa59fc2..a923d3e237d 100644
--- a/core/distributed/matrix.cpp
+++ b/core/distributed/matrix.cpp
@@ -5,6 +5,7 @@
 #include "ginkgo/core/distributed/matrix.hpp"
 
 #include <ginkgo/core/base/precision_dispatch.hpp>
+#include <ginkgo/core/distributed/row_gatherer.hpp>
 #include <ginkgo/core/distributed/vector.hpp>
 #include <ginkgo/core/matrix/coo.hpp>
 #include <ginkgo/core/matrix/csr.hpp>
@@ -30,57 +31,14 @@ GKO_REGISTER_OPERATION(separate_local_nonlocal,
 
 
 template <typename LocalIndexType, typename GlobalIndexType>
 void initialize_communication_pattern(
-    std::shared_ptr<const Executor> exec, mpi::communicator comm,
     const index_map<LocalIndexType, GlobalIndexType>& imap,
-    std::vector<comm_index_type>& recv_sizes,
-    std::vector<comm_index_type>& recv_offsets,
-    std::vector<comm_index_type>& send_sizes,
-    std::vector<comm_index_type>& send_offsets,
-    array<LocalIndexType>& gather_idxs)
+    std::shared_ptr<RowGatherer<LocalIndexType>>& row_gatherer)
 {
-    // exchange step 1: determine recv_sizes, send_sizes, send_offsets
-    auto host_recv_targets =
-        make_temporary_clone(exec->get_master(), &imap.get_remote_target_ids());
-    auto host_offsets = make_temporary_clone(
-        exec->get_master(), &imap.get_remote_global_idxs().get_offsets());
-    auto compute_recv_sizes = [](const auto* recv_targets, size_type size,
-                                 const auto* offsets, auto& recv_sizes) {
-        for (size_type i = 0; i < size; ++i) {
-            recv_sizes[recv_targets[i]] = offsets[i + 1] - offsets[i];
-        }
-    };
-    std::fill(recv_sizes.begin(), recv_sizes.end(), 0);
-    compute_recv_sizes(host_recv_targets->get_const_data(),
-                       host_recv_targets->get_size(),
-                       host_offsets->get_const_data(), recv_sizes);
-    std::partial_sum(recv_sizes.begin(), recv_sizes.end(),
-                     recv_offsets.begin() + 1);
-    comm.all_to_all(exec, recv_sizes.data(), 1, send_sizes.data(), 1);
-    std::partial_sum(send_sizes.begin(), send_sizes.end(),
-                     send_offsets.begin() + 1);
-    send_offsets[0] = 0;
-    recv_offsets[0] = 0;
-
-    // exchange step 2: exchange gather_idxs from receivers to senders
-    auto recv_gather_idxs =
-        make_const_array_view(
-            imap.get_executor(), imap.get_non_local_size(),
-            imap.get_remote_local_idxs().get_const_flat_data())
-            .copy_to_array();
-    auto use_host_buffer = mpi::requires_host_buffer(exec, comm);
-    if (use_host_buffer) {
-        recv_gather_idxs.set_executor(exec->get_master());
-        gather_idxs.clear();
-        gather_idxs.set_executor(exec->get_master());
-    }
-    gather_idxs.resize_and_reset(send_offsets.back());
-    comm.all_to_all_v(use_host_buffer ? exec->get_master() : exec,
-                      recv_gather_idxs.get_const_data(), recv_sizes.data(),
-                      recv_offsets.data(), gather_idxs.get_data(),
-                      send_sizes.data(), send_offsets.data());
-    if (use_host_buffer) {
-        gather_idxs.set_executor(exec);
-    }
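+    // the communication pattern is now fully encapsulated in the RowGatherer:
+    // recreate it from the index map, reusing the executor and the concrete
+    // collective communicator type of the previous instance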
+    row_gatherer = RowGatherer<LocalIndexType>::create(
+        row_gatherer->get_executor(),
+        row_gatherer->get_collective_communicator()->create_with_same_type(
+            row_gatherer->get_communicator(), imap),
+        imap);
 }
 
 
@@ -101,12 +59,8 @@ Matrix<ValueType, LocalIndexType, GlobalIndexType>::Matrix(
     : EnableDistributedLinOp<
           Matrix<ValueType, LocalIndexType, GlobalIndexType>>{exec},
       DistributedBase{comm},
-      imap_(exec),
-      send_offsets_(comm.size() + 1),
-      send_sizes_(comm.size()),
-      recv_offsets_(comm.size() + 1),
-      recv_sizes_(comm.size()),
-      gather_idxs_{exec},
+      row_gatherer_{RowGatherer<LocalIndexType>::create(exec, comm)},
+      imap_{exec},
       one_scalar_{},
       local_mtx_{local_matrix_template->clone(exec)},
       non_local_mtx_{non_local_matrix_template->clone(exec)}
@@ -128,12 +82,8 @@ Matrix<ValueType, LocalIndexType, GlobalIndexType>::Matrix(
     : EnableDistributedLinOp<
           Matrix<ValueType, LocalIndexType, GlobalIndexType>>{exec},
       DistributedBase{comm},
-      imap_(exec),
-      send_offsets_(comm.size() + 1),
-      send_sizes_(comm.size()),
-      recv_offsets_(comm.size() + 1),
-      recv_sizes_(comm.size()),
-      gather_idxs_{exec},
+      row_gatherer_{RowGatherer<LocalIndexType>::create(exec, comm)},
+      imap_{exec},
       one_scalar_{},
       non_local_mtx_(::gko::matrix::Coo<ValueType, LocalIndexType>::create(
          exec, dim<2>{local_linop->get_size()[0], 0}))
@@ -152,12 +102,8 @@ Matrix<ValueType, LocalIndexType, GlobalIndexType>::Matrix(
     : EnableDistributedLinOp<
           Matrix<ValueType, LocalIndexType, GlobalIndexType>>{exec},
       DistributedBase{comm},
+      row_gatherer_(RowGatherer<LocalIndexType>::create(exec, comm)),
       imap_(std::move(imap)),
-      send_offsets_(comm.size() + 1),
-      send_sizes_(comm.size()),
-      recv_offsets_(comm.size() + 1),
-      recv_sizes_(comm.size()),
-      gather_idxs_{exec},
       one_scalar_{}
 {
     this->set_size({imap_.get_global_size(), imap_.get_global_size()});
@@ -166,9 +112,7 @@ Matrix<ValueType, LocalIndexType, GlobalIndexType>::Matrix(
     one_scalar_.init(exec, dim<2>{1, 1});
     one_scalar_->fill(one<value_type>());
 
-    initialize_communication_pattern(
-        this->get_executor(), this->get_communicator(), imap_, recv_sizes_,
-        recv_offsets_, send_sizes_, send_offsets_, gather_idxs_);
+    initialize_communication_pattern(imap_, row_gatherer_);
 }
 
 
@@ -235,12 +179,8 @@ void Matrix<ValueType, LocalIndexType, GlobalIndexType>::convert_to(
         result->get_communicator().size());
     result->local_mtx_->copy_from(this->local_mtx_);
     result->non_local_mtx_->copy_from(this->non_local_mtx_);
+    result->row_gatherer_->copy_from(this->row_gatherer_);
     result->imap_ = this->imap_;
-    result->gather_idxs_ = this->gather_idxs_;
-    result->send_offsets_ = this->send_offsets_;
-    result->recv_offsets_ = this->recv_offsets_;
-    result->recv_sizes_ = this->recv_sizes_;
-    result->send_sizes_ = this->send_sizes_;
     result->set_size(this->get_size());
 }
 
@@ -254,12 +194,8 @@ void Matrix<ValueType, LocalIndexType, GlobalIndexType>::move_to(
         result->get_communicator().size());
     result->local_mtx_->move_from(this->local_mtx_);
     result->non_local_mtx_->move_from(this->non_local_mtx_);
+    result->row_gatherer_->move_from(this->row_gatherer_);
     result->imap_ = std::move(this->imap_);
-    result->gather_idxs_ = std::move(this->gather_idxs_);
-    result->send_offsets_ = std::move(this->send_offsets_);
-    result->recv_offsets_ = std::move(this->recv_offsets_);
-    result->recv_sizes_ = std::move(this->recv_sizes_);
-    result->send_sizes_ = std::move(this->send_sizes_);
     result->set_size(this->get_size());
     this->set_size({});
 }
@@ -282,7 +218,6 @@ void Matrix<ValueType, LocalIndexType, GlobalIndexType>::read_distributed(
     auto local_part = comm.rank();
 
     // set up LinOp sizes
-    auto num_parts = static_cast<size_type>(row_partition->get_num_parts());
     auto global_num_rows = row_partition->get_size();
     auto global_num_cols = col_partition->get_size();
     dim<2> global_dim{global_num_rows, global_num_cols};
@@ -329,9 +264,7 @@ void Matrix<ValueType, LocalIndexType, GlobalIndexType>::read_distributed(
    as<ReadableFromMatrixData<ValueType, LocalIndexType>>(this->non_local_mtx_)
         ->read(std::move(non_local_data));
 
-    initialize_communication_pattern(exec, comm, imap_, recv_sizes_,
-                                     recv_offsets_, send_sizes_, send_offsets_,
-                                     gather_idxs_);
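+    // the index map assembled during read_distributed carries all information
+    // needed to rebuild the row gatherer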
+    initialize_communication_pattern(imap_, row_gatherer_);
 }
 
 
@@ -373,55 +306,6 @@ void Matrix<ValueType, LocalIndexType, GlobalIndexType>::read_distributed(
 }
 
 
-template <typename ValueType, typename LocalIndexType, typename GlobalIndexType>
-mpi::request Matrix<ValueType, LocalIndexType, GlobalIndexType>::communicate(
-    const local_vector_type* local_b) const
-{
-    // This function can never return early!
-    // Even if the non-local part is empty, i.e. this process doesn't need
-    // any data from other processes, the used MPI calls are collective
-    // operations. They need to be called on all processes, even if a process
-    // might not communicate any data.
-    auto exec = this->get_executor();
-    const auto comm = this->get_communicator();
-    auto num_cols = local_b->get_size()[1];
-    auto send_size = send_offsets_.back();
-    auto recv_size = recv_offsets_.back();
-    auto send_dim = dim<2>{static_cast<size_type>(send_size), num_cols};
-    auto recv_dim = dim<2>{static_cast<size_type>(recv_size), num_cols};
-    recv_buffer_.init(exec, recv_dim);
-    send_buffer_.init(exec, send_dim);
-
-    local_b->row_gather(&gather_idxs_, send_buffer_.get());
-
-    auto use_host_buffer = mpi::requires_host_buffer(exec, comm);
-    if (use_host_buffer) {
-        host_recv_buffer_.init(exec->get_master(), recv_dim);
-        host_send_buffer_.init(exec->get_master(), send_dim);
-        host_send_buffer_->copy_from(send_buffer_.get());
-    }
-
-    mpi::contiguous_type type(num_cols, mpi::type_impl<ValueType>::get_type());
-    auto send_ptr = use_host_buffer ? host_send_buffer_->get_const_values()
-                                    : send_buffer_->get_const_values();
-    auto recv_ptr = use_host_buffer ? host_recv_buffer_->get_values()
-                                    : recv_buffer_->get_values();
-    exec->synchronize();
-#ifdef GINKGO_HAVE_OPENMPI_PRE_4_1_X
-    comm.all_to_all_v(use_host_buffer ? exec->get_master() : exec, send_ptr,
-                      send_sizes_.data(), send_offsets_.data(), type.get(),
-                      recv_ptr, recv_sizes_.data(), recv_offsets_.data(),
-                      type.get());
-    return {};
-#else
-    return comm.i_all_to_all_v(
-        use_host_buffer ? exec->get_master() : exec, send_ptr,
-        send_sizes_.data(), send_offsets_.data(), type.get(), recv_ptr,
-        recv_sizes_.data(), recv_offsets_.data(), type.get());
-#endif
-}
-
-
 template <typename ValueType, typename LocalIndexType, typename GlobalIndexType>
 void Matrix<ValueType, LocalIndexType, GlobalIndexType>::apply_impl(
     const LinOp* b, LinOp* x) const
@@ -437,16 +321,22 @@ void Matrix<ValueType, LocalIndexType, GlobalIndexType>::apply_impl(
                 dense_x->get_local_values()),
             dense_x->get_local_vector()->get_stride());
 
+            auto exec = this->get_executor();
             auto comm = this->get_communicator();
-            auto req = this->communicate(dense_b->get_local_vector());
+            auto recv_dim =
+                dim<2>{static_cast<size_type>(
+                           row_gatherer_->get_collective_communicator()
+                               ->get_recv_size()),
+                       dense_b->get_size()[1]};
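+            // place the receive buffer on the host if the MPI implementation
+            // cannot access device memory directly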
+            auto recv_exec = mpi::requires_host_buffer(exec, comm)
+                                 ? exec->get_master()
+                                 : exec;
+            recv_buffer_.init(recv_exec, recv_dim);
+            auto req =
+                this->row_gatherer_->apply_async(dense_b, recv_buffer_.get());
             local_mtx_->apply(dense_b->get_local_vector(), local_x);
             req.wait();
 
-            auto exec = this->get_executor();
-            auto use_host_buffer = mpi::requires_host_buffer(exec, comm);
-            if (use_host_buffer) {
-                recv_buffer_->copy_from(host_recv_buffer_.get());
-            }
             non_local_mtx_->apply(one_scalar_.get(), recv_buffer_.get(),
                                   one_scalar_.get(), local_x);
         },
@@ -470,17 +360,23 @@ void Matrix<ValueType, LocalIndexType, GlobalIndexType>::apply_impl(
                 dense_x->get_local_values()),
             dense_x->get_local_vector()->get_stride());
 
+            auto exec = this->get_executor();
             auto comm = this->get_communicator();
-            auto req = this->communicate(dense_b->get_local_vector());
+            auto recv_dim =
+                dim<2>{static_cast<size_type>(
+                           row_gatherer_->get_collective_communicator()
+                               ->get_recv_size()),
+                       dense_b->get_size()[1]};
+            auto recv_exec = mpi::requires_host_buffer(exec, comm)
+                                 ? exec->get_master()
+                                 : exec;
+            recv_buffer_.init(recv_exec, recv_dim);
+            auto req =
+                this->row_gatherer_->apply_async(dense_b, recv_buffer_.get());
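+            // the local SpMV overlaps with the non-blocking communication
+            // started by apply_async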
             local_mtx_->apply(local_alpha, dense_b->get_local_vector(),
                               local_beta, local_x);
             req.wait();
 
-            auto exec = this->get_executor();
-            auto use_host_buffer = mpi::requires_host_buffer(exec, comm);
-            if (use_host_buffer) {
-                recv_buffer_->copy_from(host_recv_buffer_.get());
-            }
             non_local_mtx_->apply(local_alpha, recv_buffer_.get(),
                                   one_scalar_.get(), local_x);
         },
@@ -563,6 +459,8 @@ Matrix<ValueType, LocalIndexType, GlobalIndexType>::Matrix(const Matrix& other)
     : EnableDistributedLinOp<
          Matrix<ValueType, LocalIndexType, GlobalIndexType>>{other.get_executor()},
       DistributedBase{other.get_communicator()},
+      row_gatherer_{RowGatherer<LocalIndexType>::create(
+          other.get_executor(), other.get_communicator())},
      imap_(other.get_executor())
 {
     *this = other;
 }
@@ -575,6 +473,8 @@ Matrix<ValueType, LocalIndexType, GlobalIndexType>::Matrix(
     : EnableDistributedLinOp<
          Matrix<ValueType, LocalIndexType, GlobalIndexType>>{other.get_executor()},
       DistributedBase{other.get_communicator()},
+      row_gatherer_{RowGatherer<LocalIndexType>::create(
+          other.get_executor(), other.get_communicator())},
      imap_(other.get_executor())
 {
     *this = std::move(other);
 }
@@ -592,12 +492,8 @@ Matrix<ValueType, LocalIndexType, GlobalIndexType>::operator=(
         this->set_size(other.get_size());
         local_mtx_->copy_from(other.local_mtx_);
         non_local_mtx_->copy_from(other.non_local_mtx_);
+        row_gatherer_->copy_from(other.row_gatherer_);
         imap_ = other.imap_;
-        gather_idxs_ = other.gather_idxs_;
-        send_offsets_ = other.send_offsets_;
-        recv_offsets_ = other.recv_offsets_;
-        send_sizes_ = other.send_sizes_;
-        recv_sizes_ = other.recv_sizes_;
         one_scalar_.init(this->get_executor(), dim<2>{1, 1});
         one_scalar_->fill(one<value_type>());
     }
@@ -616,12 +512,8 @@ Matrix<ValueType, LocalIndexType, GlobalIndexType>::operator=(Matrix&& other)
         other.set_size({});
         local_mtx_->move_from(other.local_mtx_);
         non_local_mtx_->move_from(other.non_local_mtx_);
+        row_gatherer_->move_from(other.row_gatherer_);
         imap_ = std::move(other.imap_);
-        gather_idxs_ = std::move(other.gather_idxs_);
-        send_offsets_ = std::move(other.send_offsets_);
-        recv_offsets_ = std::move(other.recv_offsets_);
-        send_sizes_ = std::move(other.send_sizes_);
-        recv_sizes_ = std::move(other.recv_sizes_);
         one_scalar_.init(this->get_executor(), dim<2>{1, 1});
         one_scalar_->fill(one<value_type>());
     }
diff --git a/include/ginkgo/core/distributed/matrix.hpp b/include/ginkgo/core/distributed/matrix.hpp
index 09070d0ca55..cee947b1cae 100644
--- a/include/ginkgo/core/distributed/matrix.hpp
+++ b/include/ginkgo/core/distributed/matrix.hpp
@@ -17,6 +17,7 @@
 #include <ginkgo/core/distributed/base.hpp>
 #include <ginkgo/core/distributed/index_map.hpp>
 #include <ginkgo/core/distributed/lin_op.hpp>
+#include <ginkgo/core/distributed/row_gatherer.hpp>
 
 
 namespace gko {
@@ -610,32 +611,15 @@ class Matrix
                     std::shared_ptr<LinOp> local_linop,
                     std::shared_ptr<LinOp> non_local_linop);
 
-    /**
-     * Starts a non-blocking communication of the values of b that are shared
-     * with other processors.
-     *
-     * @param local_b  The full local vector to be communicated. The subset of
-     *                 shared values is automatically extracted.
-     * @return  MPI request for the non-blocking communication.
-     */
-    mpi::request communicate(const local_vector_type* local_b) const;
-
     void apply_impl(const LinOp* b, LinOp* x) const override;
 
     void apply_impl(const LinOp* alpha, const LinOp* b, const LinOp* beta,
                     LinOp* x) const override;
 
 private:
+    std::shared_ptr<RowGatherer<LocalIndexType>> row_gatherer_;
     index_map<LocalIndexType, GlobalIndexType> imap_;
-    std::vector<comm_index_type> send_offsets_;
-    std::vector<comm_index_type> send_sizes_;
-    std::vector<comm_index_type> recv_offsets_;
-    std::vector<comm_index_type> recv_sizes_;
-    array<LocalIndexType> gather_idxs_;
     gko::detail::DenseCache<ValueType> one_scalar_;
-    gko::detail::DenseCache<ValueType> host_send_buffer_;
-    gko::detail::DenseCache<ValueType> host_recv_buffer_;
-    gko::detail::DenseCache<ValueType> send_buffer_;
     gko::detail::DenseCache<ValueType> recv_buffer_;
     std::shared_ptr<LinOp> local_mtx_;
     std::shared_ptr<LinOp> non_local_mtx_;
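
Note: with communicate() removed, both apply_impl overloads reduce to the same
overlap pattern. A minimal sketch, using only the RowGatherer calls that appear
in this patch (variable names as in the matrix class):

    // start the non-blocking gather of the remote vector entries
    auto req = row_gatherer_->apply_async(dense_b, recv_buffer_.get());
    // the local part of the SpMV runs while the communication is in flight
    local_mtx_->apply(dense_b->get_local_vector(), local_x);
    req.wait();
    // accumulate the non-local contribution from the gathered entries
    non_local_mtx_->apply(one_scalar_.get(), recv_buffer_.get(),
                          one_scalar_.get(), local_x);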