diff --git a/core/CMakeLists.txt b/core/CMakeLists.txt
index f7f1e00f17b..b814e660aca 100644
--- a/core/CMakeLists.txt
+++ b/core/CMakeLists.txt
@@ -135,6 +135,7 @@ if(GINKGO_BUILD_MPI)
         distributed/matrix.cpp
         distributed/neighborhood_communicator.cpp
         distributed/partition_helpers.cpp
+        distributed/row_gatherer.cpp
         distributed/vector.cpp
         distributed/preconditioner/schwarz.cpp)
 endif()
diff --git a/core/distributed/row_gatherer.cpp b/core/distributed/row_gatherer.cpp
new file mode 100644
index 00000000000..8dfba55a9c1
--- /dev/null
+++ b/core/distributed/row_gatherer.cpp
@@ -0,0 +1,215 @@
+// SPDX-FileCopyrightText: 2017 - 2024 The Ginkgo authors
+//
+// SPDX-License-Identifier: BSD-3-Clause
+
+#include "ginkgo/core/distributed/row_gatherer.hpp"
+
+#include <ginkgo/core/base/precision_dispatch.hpp>
+#include <ginkgo/core/distributed/dense_communicator.hpp>
+#include <ginkgo/core/distributed/neighborhood_communicator.hpp>
+#include <ginkgo/core/distributed/vector.hpp>
+#include <ginkgo/core/matrix/dense.hpp>
+
+#include "core/base/dispatch_helper.hpp"
+
+namespace gko {
+namespace experimental {
+namespace distributed {
+
+
+// Default collective communicator: use the neighborhood communicator only on
+// OpenMPI > 4.1.x, where neighborhood collectives are reliable.
+#if GINKGO_HAVE_OPENMPI_POST_4_1_X
+using DefaultCollComm = mpi::NeighborhoodCommunicator;
+#else
+using DefaultCollComm = mpi::DenseCommunicator;
+#endif
+
+
+template <typename LocalIndexType>
+void RowGatherer<LocalIndexType>::apply_impl(const LinOp* b, LinOp* x) const
+{
+    // blocking apply simply waits on the asynchronous version
+    apply_async(b, x).wait();
+}
+
+
+template <typename LocalIndexType>
+void RowGatherer<LocalIndexType>::apply_impl(const LinOp* alpha, const LinOp* b,
+                                             const LinOp* beta, LinOp* x) const
+    GKO_NOT_IMPLEMENTED;
+
+
+template <typename LocalIndexType>
+std::future<void> RowGatherer<LocalIndexType>::apply_async(
+    ptr_param<const LinOp> b, ptr_param<LinOp> x) const
+{
+    auto op = [b = b.get(), x = x.get(), rg = this->shared_from_this(),
+               id = current_id_++] {
+        // ensure that the communications are executed in the order
+        // the apply_async were called
+        while (id > rg->active_id_.load()) {
+            std::this_thread::yield();
+        }
+
+        // dispatch global vector
+        run<Vector, double, float, std::complex<double>,
+            std::complex<float>>(b, [&](const auto* b_global) {
+            using ValueType =
+                typename std::decay_t<decltype(*b_global)>::value_type;
+            // dispatch local vector with the same precision as the global
+            // vector
+            ::gko::precision_dispatch<ValueType>(
+                [&](auto* x_local) {
+                    auto exec = rg->get_executor();
+
+                    // gather the needed local rows into the send buffer
+                    auto b_local = b_global->get_local_vector();
+                    rg->send_buffer.template init<ValueType>(
+                        b_local->get_executor(),
+                        dim<2>(rg->coll_comm_->get_send_size(),
+                               b_local->get_size()[1]));
+                    rg->send_buffer.template get<ValueType>()->fill(0.0);
+                    b_local->row_gather(
+                        &rg->send_idxs_,
+                        rg->send_buffer.template get<ValueType>());
+
+                    auto recv_ptr = x_local->get_values();
+                    auto send_ptr =
+                        rg->send_buffer.template get<ValueType>()
+                            ->get_values();
+
+                    // a row is communicated as one contiguous MPI type
+                    exec->synchronize();
+                    mpi::contiguous_type type(
+                        b_local->get_size()[1],
+                        mpi::type_impl<ValueType>::get_type());
+                    auto g = exec->get_scoped_device_id_guard();
+                    auto req = rg->coll_comm_->i_all_to_all_v(
+                        exec, send_ptr, type.get(), recv_ptr, type.get());
+                    req.wait();
+                },
+                x);
+        });
+
+        // unblock the next queued apply_async
+        rg->active_id_++;
+    };
+    return std::async(std::launch::async, op);
+}
+
+
+template <typename LocalIndexType>
+std::shared_ptr<const mpi::CollectiveCommunicator>
+RowGatherer<LocalIndexType>::get_collective_communicator() const
+{
+    return coll_comm_;
+}
+
+
+template <typename LocalIndexType>
+template <typename GlobalIndexType>
+RowGatherer<LocalIndexType>::RowGatherer(
+    std::shared_ptr<const Executor> exec,
+    std::shared_ptr<const mpi::CollectiveCommunicator> coll_comm,
+    const index_map<LocalIndexType, GlobalIndexType>& imap)
+    : EnableDistributedLinOp<RowGatherer<LocalIndexType>>(
+          exec, dim<2>{imap.get_non_local_size(), imap.get_global_size()}),
+      DistributedBase(coll_comm->get_base_communicator()),
+      coll_comm_(std::move(coll_comm)),
+      send_idxs_(exec)
+{
+    // check that the coll_comm_ and imap have the same recv size
+    // the same check for the send size is not possible, since the
+    // imap doesn't store send indices
+    GKO_THROW_IF_INVALID(
+        coll_comm_->get_recv_size() == imap.get_non_local_size(),
+        "The collective communicator doesn't match the index map.");
+
+    auto comm = coll_comm_->get_base_communicator();
+    auto inverse_comm = coll_comm_->create_inverse();
+
+    // exchange the locally-owned indices that other ranks request, using the
+    // inverse communication pattern
+    send_idxs_.resize_and_reset(coll_comm_->get_send_size());
+    inverse_comm
+        ->i_all_to_all_v(exec,
+                         imap.get_remote_local_idxs().get_const_flat_data(),
+                         send_idxs_.get_data())
+        .wait();
+}
+
+
+template <typename LocalIndexType>
+RowGatherer<LocalIndexType>::RowGatherer(std::shared_ptr<const Executor> exec,
+                                         mpi::communicator comm)
+    : EnableDistributedLinOp<RowGatherer<LocalIndexType>>(exec),
+      DistributedBase(comm),
+      coll_comm_(std::make_shared<DefaultCollComm>(comm)),
+      send_idxs_(exec)
+{}
+
+
+template <typename LocalIndexType>
+RowGatherer<LocalIndexType>::RowGatherer(RowGatherer&& o) noexcept
+    : EnableDistributedLinOp<RowGatherer<LocalIndexType>>(o.get_executor()),
+      DistributedBase(o.get_communicator()),
+      send_idxs_(o.get_executor())
+{
+    *this = std::move(o);
+}
+
+
+template <typename LocalIndexType>
+RowGatherer<LocalIndexType>& RowGatherer<LocalIndexType>::operator=(
+    const RowGatherer& o)
+{
+    if (this != &o) {
+        this->set_size(o.get_size());
+        coll_comm_ = o.coll_comm_;
+        send_idxs_ = o.send_idxs_;
+    }
+    return *this;
+}
+
+
+template <typename LocalIndexType>
+RowGatherer<LocalIndexType>& RowGatherer<LocalIndexType>::operator=(
+    RowGatherer&& o)
+{
+    if (this != &o) {
+        this->set_size(o.get_size());
+        o.set_size({});
+        // moved-from object keeps a valid default communicator
+        coll_comm_ = std::exchange(
+            o.coll_comm_,
+            std::make_shared<DefaultCollComm>(o.get_communicator()));
+        send_idxs_ = std::move(o.send_idxs_);
+    }
+    return *this;
+}
+
+
+template <typename LocalIndexType>
+RowGatherer<LocalIndexType>::RowGatherer(const RowGatherer& o)
+    : EnableDistributedLinOp<RowGatherer<LocalIndexType>>(o.get_executor()),
+      DistributedBase(o.get_communicator()),
+      send_idxs_(o.get_executor())
+{
+    *this = o;
+}
+
+
+#define GKO_DECLARE_ROW_GATHERER(_itype) class RowGatherer<_itype>
+
+GKO_INSTANTIATE_FOR_EACH_INDEX_TYPE(GKO_DECLARE_ROW_GATHERER);
+
+#undef GKO_DECLARE_ROW_GATHERER
+
+
+#define GKO_DECLARE_ROW_GATHERER_CONSTRUCTOR(_ltype, _gtype)          \
+    RowGatherer<_ltype>::RowGatherer(                                 \
+        std::shared_ptr<const Executor> exec,                         \
+        std::shared_ptr<const mpi::CollectiveCommunicator> coll_comm, \
+        const index_map<_ltype, _gtype>& imap)
+
+GKO_INSTANTIATE_FOR_EACH_LOCAL_GLOBAL_INDEX_TYPE(
+    GKO_DECLARE_ROW_GATHERER_CONSTRUCTOR);
+
+#undef GKO_DECLARE_ROW_GATHERER_CONSTRUCTOR
+
+
+}  // namespace distributed
+}  // namespace experimental
+}  // namespace gko
diff --git a/include/ginkgo/core/distributed/row_gatherer.hpp b/include/ginkgo/core/distributed/row_gatherer.hpp
new file mode 100644
index 00000000000..3af989d966c
--- /dev/null
+++ b/include/ginkgo/core/distributed/row_gatherer.hpp
@@ -0,0 +1,173 @@
+// SPDX-FileCopyrightText: 2017 - 2024 The Ginkgo authors
+//
+// SPDX-License-Identifier: BSD-3-Clause
+
+#ifndef GKO_PUBLIC_CORE_DISTRIBUTED_ROW_GATHERER_HPP_
+#define GKO_PUBLIC_CORE_DISTRIBUTED_ROW_GATHERER_HPP_
+
+
+#include <ginkgo/config.hpp>
+
+
+#if GINKGO_BUILD_MPI
+
+
+#include <future>
+
+#include <ginkgo/core/base/array.hpp>
+#include <ginkgo/core/base/dense_cache.hpp>
+#include <ginkgo/core/distributed/collective_communicator.hpp>
+#include <ginkgo/core/distributed/index_map.hpp>
+#include <ginkgo/core/distributed/lin_op.hpp>
+
+
+namespace gko {
+namespace experimental {
+namespace distributed {
+
+
+/**
+ * The distributed::RowGatherer gathers the rows of distributed::Vector that
+ * are located on other processes.
+ *
+ * Example usage:
+ * ```c++
+ * auto coll_comm = std::make_shared<mpi::NeighborhoodCommunicator>(comm,
+ * imap); auto rg = distributed::RowGatherer<int32>::create(exec, coll_comm,
+ * imap);
+ *
+ * auto b = distributed::Vector<double>::create(...);
+ * auto x = matrix::Dense<double>::create(...);
+ *
+ * auto future = rg->apply_async(b, x);
+ * // do some computation that doesn't modify b, or access x
+ * future.wait();
+ * // x now contains the gathered rows of b
+ * ```
+ * Using the apply instead of the apply_async will lead to a blocking
+ * communication.
+ *
+ * @note Objects of this class are only available as shared_ptr, since the class
+ * is derived from std::enable_shared_from_this.
+ *
+ * @tparam LocalIndexType  the index type for the stored indices
+ */
+template <typename LocalIndexType = int32>
+class RowGatherer final
+    : public EnableDistributedLinOp<RowGatherer<LocalIndexType>>,
+      public DistributedBase,
+      public std::enable_shared_from_this<RowGatherer<LocalIndexType>> {
+    friend class EnableDistributedPolymorphicObject<RowGatherer, LinOp>;
+
+public:
+    /**
+     * Asynchronous version of LinOp::apply.
+     *
+     * It is asynchronous only wrt. the calling thread. Multiple calls to this
+     * function will execute in order, they are not asynchronous with each
+     * other.
+     *
+     * @param b  the input distributed::Vector
+     * @param x  the output matrix::Dense with the rows gathered from b
+     *
+     * @return a future for this task. The task is guaranteed to be completed
+     *         after `.wait()` has been called on the future.
+     */
+    std::future<void> apply_async(ptr_param<const LinOp> b,
+                                  ptr_param<LinOp> x) const;
+
+    /**
+     * Get the used collective communicator.
+     *
+     * @return the used collective communicator
+     */
+    std::shared_ptr<const mpi::CollectiveCommunicator>
+    get_collective_communicator() const;
+
+    /**
+     * Creates a distributed::RowGatherer from a given collective communicator
+     * and index map.
+     *
+     * @TODO: using a segmented array instead of the imap would probably be
+     *        more general
+     *
+     * @tparam GlobalIndexType  the global index type of the index map
+     *
+     * @param exec  the executor
+     * @param coll_comm  the collective communicator
+     * @param imap  the index map defining which rows to gather
+     *
+     * @note The coll_comm and imap have to be compatible. The coll_comm must
+     *       send and recv exactly as many rows as the imap defines.
+     *
+     * @return a shared_ptr to the created distributed::RowGatherer
+     */
+    template <typename GlobalIndexType,
+              typename = std::enable_if_t<sizeof(GlobalIndexType) >=
+                                          sizeof(LocalIndexType)>>
+    static std::shared_ptr<RowGatherer> create(
+        std::shared_ptr<const Executor> exec,
+        std::shared_ptr<const mpi::CollectiveCommunicator> coll_comm,
+        const index_map<LocalIndexType, GlobalIndexType>& imap)
+    {
+        return std::shared_ptr<RowGatherer>(
+            new RowGatherer(std::move(exec), std::move(coll_comm), imap));
+    }
+
+    /*
+     * Create method for an empty RowGatherer.
+     */
+    static std::shared_ptr<RowGatherer> create(
+        std::shared_ptr<const Executor> exec, mpi::communicator comm)
+    {
+        return std::shared_ptr<RowGatherer>(
+            new RowGatherer(std::move(exec), std::move(comm)));
+    }
+
+    RowGatherer(const RowGatherer& o);
+
+    RowGatherer(RowGatherer&& o) noexcept;
+
+    RowGatherer& operator=(const RowGatherer& o);
+
+    RowGatherer& operator=(RowGatherer&& o);
+
+protected:
+    void apply_impl(const LinOp* b, LinOp* x) const override;
+
+    void apply_impl(const LinOp* alpha, const LinOp* b, const LinOp* beta,
+                    LinOp* x) const override;
+
+private:
+    /**
+     * @copydoc RowGatherer::create(std::shared_ptr<const Executor>,
+     *          std::shared_ptr<const mpi::CollectiveCommunicator>,
+     *          const index_map<LocalIndexType, GlobalIndexType>&)
+     */
+    template <typename GlobalIndexType>
+    RowGatherer(std::shared_ptr<const Executor> exec,
+                std::shared_ptr<const mpi::CollectiveCommunicator> coll_comm,
+                const index_map<LocalIndexType, GlobalIndexType>& imap);
+
+    /**
+     * @copydoc RowGatherer::create(std::shared_ptr<const Executor>,
+     *          mpi::communicator)
+     */
+    RowGatherer(std::shared_ptr<const Executor> exec, mpi::communicator comm);
+
+    // communication pattern used by apply_async
+    std::shared_ptr<const mpi::CollectiveCommunicator> coll_comm_;
+
+    // locally-owned rows that other ranks request
+    array<LocalIndexType> send_idxs_;
+
+    // value-type-agnostic staging buffer for the gathered rows
+    detail::AnyDenseCache send_buffer;
+
+    // ticket counters that serialize concurrent apply_async calls
+    mutable int64 current_id_{0};
+    mutable std::atomic<int64> active_id_{0};
+};
+
+
+}  // namespace distributed
+}  // namespace experimental
+}  // namespace gko
+
+
+#endif
+#endif  // GKO_PUBLIC_CORE_DISTRIBUTED_ROW_GATHERER_HPP_
diff --git a/include/ginkgo/ginkgo.hpp b/include/ginkgo/ginkgo.hpp
index 9d897ce8762..9051b95e5ef 100644
--- a/include/ginkgo/ginkgo.hpp
+++ b/include/ginkgo/ginkgo.hpp
@@ -72,6 +72,7 @@
 #include <ginkgo/core/distributed/partition.hpp>
 #include <ginkgo/core/distributed/partition_helpers.hpp>
+#include <ginkgo/core/distributed/row_gatherer.hpp>
 #include <ginkgo/core/distributed/vector.hpp>
 #include <ginkgo/core/distributed/preconditioner/schwarz.hpp>