diff --git a/core/distributed/row_gatherer.cpp b/core/distributed/row_gatherer.cpp
index 9b6292773f3..3cc80b2e7f7 100644
--- a/core/distributed/row_gatherer.cpp
+++ b/core/distributed/row_gatherer.cpp
@@ -38,77 +38,85 @@ void RowGatherer<LocalIndexType>::apply_impl(const LinOp* alpha, const LinOp* b,
 
 
 template <typename LocalIndexType>
-std::future<void> RowGatherer<LocalIndexType>::apply_async(
-    ptr_param<const LinOp> b, ptr_param<LinOp> x) const
+mpi::request RowGatherer<LocalIndexType>::apply_async(ptr_param<const LinOp> b,
+                                                      ptr_param<LinOp> x) const
 {
-    auto op = [b = b.get(), x = x.get(), rg = this->shared_from_this(),
-               id = current_id_++] {
-        // ensure that the communications are executed in the order
-        // the apply_async were called
-        while (id > rg->active_id_.load()) {
-            std::this_thread::yield();
-        }
-
-        // dispatch global vector
-        run<Vector, double, float, std::complex<double>, std::complex<float>>(
-            b, [&](const auto* b_global) {
-                using ValueType =
-                    typename std::decay_t<decltype(*b_global)>::value_type;
-                // dispatch local vector with the same precision as the global
-                // vector
-                ::gko::precision_dispatch<ValueType>(
-                    [&](auto* x_local) {
-                        auto exec = rg->get_executor();
-
-                        auto use_host_buffer = mpi::requires_host_buffer(
-                            exec, rg->coll_comm_->get_base_communicator());
-                        auto mpi_exec =
-                            use_host_buffer ? exec->get_master() : exec;
-
-                        auto b_local = b_global->get_local_vector();
-                        rg->send_buffer_.template init<ValueType>(
-                            mpi_exec, dim<2>(rg->coll_comm_->get_send_size(),
-                                             b_local->get_size()[1]));
-                        rg->send_buffer_.template get<ValueType>()->fill(0.0);
-                        b_local->row_gather(
-                            &rg->send_idxs_,
-                            rg->send_buffer_.template get<ValueType>());
-
-                        if (use_host_buffer) {
-                            rg->recv_buffer_.template init<ValueType>(
-                                mpi_exec, x_local->get_size());
-                        }
-
-                        auto recv_ptr =
-                            use_host_buffer
-                                ? rg->recv_buffer_.template get<ValueType>()
-                                      ->get_values()
-                                : x_local->get_values();
-                        auto send_ptr =
-                            rg->send_buffer_.template get<ValueType>()
-                                ->get_values();
-
-                        mpi_exec->synchronize();
-                        mpi::contiguous_type type(
-                            b_local->get_size()[1],
-                            mpi::type_impl<ValueType>::get_type());
-                        auto g = exec->get_scoped_device_id_guard();
-                        auto req = rg->coll_comm_->i_all_to_all_v(
-                            mpi_exec, send_ptr, type.get(), recv_ptr,
-                            type.get());
-                        req.wait();
-
-                        if (use_host_buffer) {
-                            x_local->copy_from(
-                                rg->recv_buffer_.template get<ValueType>());
-                        }
-                    },
-                    x);
-            });
-
-        rg->active_id_++;
-    };
-    return std::async(std::launch::async, op);
+    int is_inactive;
+    MPI_Status status;
+    GKO_ASSERT_NO_MPI_ERRORS(
+        MPI_Request_get_status(req_listener_, &is_inactive, &status));
+    // This is untestable. Some processes might complete the previous request
+    // while others don't, so it's impossible to create a predictable behavior
+    // for a test.
+    GKO_THROW_IF_INVALID(is_inactive,
+                         "Tried to call RowGatherer::apply_async while there "
+                         "is already an active communication. Please use the "
+                         "overload with a workspace to handle multiple "
+                         "connections.");
+
+    auto req = apply_async(b, x, send_workspace_);
+    req_listener_ = *req.get();
+    return req;
+}
+
+
+template <typename LocalIndexType>
+mpi::request RowGatherer<LocalIndexType>::apply_async(
+    ptr_param<const LinOp> b, ptr_param<LinOp> x, array<char>& workspace) const
+{
+    mpi::request req;
+
+    // dispatch global vector
+    run<Vector, double, float, std::complex<double>, std::complex<float>>(
+        b.get(), [&](const auto* b_global) {
+            using ValueType =
+                typename std::decay_t<decltype(*b_global)>::value_type;
+            // dispatch local vector with the same precision as the global
+            // vector
+            ::gko::precision_dispatch<ValueType>(
+                [&](auto* x_local) {
+                    auto exec = this->get_executor();
+
+                    auto use_host_buffer = mpi::requires_host_buffer(
+                        exec, coll_comm_->get_base_communicator());
+                    auto mpi_exec = use_host_buffer ? exec->get_master() : exec;
+
+                    GKO_THROW_IF_INVALID(
+                        !use_host_buffer || mpi_exec->memory_accessible(
+                                                x_local->get_executor()),
+                        "The receive buffer uses device memory, but MPI "
+                        "support of device memory is not available. Please "
+                        "provide a host buffer or enable MPI support for "
+                        "device memory.");
+
+                    auto b_local = b_global->get_local_vector();
+
+                    dim<2> send_size(coll_comm_->get_send_size(),
+                                     b_local->get_size()[1]);
+                    workspace.set_executor(mpi_exec);
+                    workspace.resize_and_reset(sizeof(ValueType) *
+                                               send_size[0] * send_size[1]);
+                    auto send_buffer = matrix::Dense<ValueType>::create(
+                        mpi_exec, send_size,
+                        make_array_view(
+                            mpi_exec, send_size[0] * send_size[1],
+                            reinterpret_cast<ValueType*>(workspace.get_data())),
+                        send_size[1]);
+                    b_local->row_gather(&send_idxs_, send_buffer);
+
+                    auto recv_ptr = x_local->get_values();
+                    auto send_ptr = send_buffer->get_values();
+
+                    mpi_exec->synchronize();
+                    mpi::contiguous_type type(
+                        b_local->get_size()[1],
+                        mpi::type_impl<ValueType>::get_type());
+                    req = coll_comm_->i_all_to_all_v(
+                        mpi_exec, send_ptr, type.get(), recv_ptr, type.get());
+                },
+                x.get());
+        });
+    return req;
 }
 
 
@@ -130,7 +138,9 @@ RowGatherer<LocalIndexType>::RowGatherer(
           exec, dim<2>{imap.get_non_local_size(), imap.get_global_size()}),
       DistributedBase(coll_comm->get_base_communicator()),
       coll_comm_(std::move(coll_comm)),
-      send_idxs_(exec)
+      send_idxs_(exec),
+      send_workspace_(exec),
+      req_listener_(MPI_REQUEST_NULL)
 {
     // check that the coll_comm_ and imap have the same recv size
     // the same check for the send size is not possible, since the
@@ -157,7 +167,9 @@ RowGatherer<LocalIndexType>::RowGatherer(std::shared_ptr<const Executor> exec,
     : EnableDistributedLinOp<RowGatherer<LocalIndexType>>(exec),
       DistributedBase(comm),
       coll_comm_(std::make_shared<DefaultCollComm>(comm)),
-      send_idxs_(exec)
+      send_idxs_(exec),
+      send_workspace_(exec),
+      req_listener_(MPI_REQUEST_NULL)
 {}
 
 
@@ -165,7 +177,9 @@ template <typename LocalIndexType>
 RowGatherer<LocalIndexType>::RowGatherer(RowGatherer&& o) noexcept
     : EnableDistributedLinOp<RowGatherer<LocalIndexType>>(o.get_executor()),
       DistributedBase(o.get_communicator()),
-      send_idxs_(o.get_executor())
+      send_idxs_(o.get_executor()),
+      send_workspace_(o.get_executor()),
+      req_listener_(MPI_REQUEST_NULL)
 {
     *this = std::move(o);
 }
@@ -195,6 +209,8 @@ RowGatherer<LocalIndexType>& RowGatherer<LocalIndexType>::operator=(
             o.coll_comm_,
             std::make_shared<DefaultCollComm>(o.get_communicator()));
         send_idxs_ = std::move(o.send_idxs_);
+        send_workspace_ = std::move(o.send_workspace_);
+        req_listener_ = std::exchange(o.req_listener_, MPI_REQUEST_NULL);
     }
     return *this;
 }
diff --git a/include/ginkgo/core/distributed/row_gatherer.hpp b/include/ginkgo/core/distributed/row_gatherer.hpp
index 78eba57556e..f2d53f0f03a 100644
--- a/include/ginkgo/core/distributed/row_gatherer.hpp
+++ b/include/ginkgo/core/distributed/row_gatherer.hpp
@@ -63,17 +63,34 @@ class RowGatherer final
     /**
      * Asynchronous version of LinOp::apply.
      *
-     * It is asynchronous only wrt. the calling thread. Multiple calls to this
-     * function will execute in order, they are not asynchronous with each
-     * other.
+     * @warning Only one mpi::request can be active at any given time. This
+     *          function will throw if another request is already active.
      *
      * @param b the input distributed::Vector
     * @param x the output matrix::Dense with the rows gathered from b
-     * @return a future for this task. The task is guarantueed to completed
-     *         after `.wait()` has been called on the future.
+     * @return a mpi::request for this task. The task is guaranteed to
+     *         be completed only after `.wait()` has been called on it.
      */
-    std::future<void> apply_async(ptr_param<const LinOp> b,
-                                  ptr_param<LinOp> x) const;
+    mpi::request apply_async(ptr_param<const LinOp> b,
+                             ptr_param<LinOp> x) const;
+
+    /**
+     * Asynchronous version of LinOp::apply.
+     *
+     * @warning Calling this multiple times with the same workspace and without
+     *          waiting on each previous request will lead to incorrect
+     *          data transfers.
+     *
+     * @param b the input distributed::Vector
+     * @param x the output matrix::Dense with the rows gathered from b
+     * @param workspace a workspace to store temporary data for the operation.
+     *                  It must not be modified before the request has been
+     *                  waited on.
+     * @return a mpi::request for this task. The task is guaranteed to
+     *         be completed only after `.wait()` has been called on it.
+     */
+    mpi::request apply_async(ptr_param<const LinOp> b, ptr_param<LinOp> x,
+                             array<char>& workspace) const;
 
     /**
      * Get the used collective communicator.
@@ -158,11 +175,9 @@ class RowGatherer final
 
     array<LocalIndexType> send_idxs_;
 
-    detail::AnyDenseCache send_buffer_;
-    detail::AnyDenseCache recv_buffer_;
+    mutable array<char> send_workspace_;
 
-    mutable int64 current_id_{0};
-    mutable std::atomic<int64> active_id_{0};
+    mutable MPI_Request req_listener_{MPI_REQUEST_NULL};
 };
 
 
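Usage sketch (editor's note, not part of the patch): the snippet below illustrates how the two apply_async overloads introduced above might be called. The helper name gather_rows_example, the chosen value/index types, and the assumption that the row gatherer, the input vector, and the output matrix were created elsewhere with consistent sizes and executors are all illustrative, not part of the Ginkgo API changed here.

#include <ginkgo/ginkgo.hpp>

// Illustrative helper (hypothetical): drives both apply_async overloads.
// rg, b, and x are assumed to be set up consistently elsewhere.
void gather_rows_example(
    std::shared_ptr<const gko::experimental::distributed::RowGatherer<gko::int32>> rg,
    std::shared_ptr<const gko::experimental::distributed::Vector<double>> b,
    std::shared_ptr<gko::matrix::Dense<double>> x)
{
    // Overload without a workspace: only one communication may be active at a
    // time, so the returned request must be waited on before the next call.
    auto req = rg->apply_async(b, x);
    // ... independent local work can overlap with the communication here ...
    req.wait();  // x now contains the rows gathered from b

    // Overload with a caller-provided workspace: several communications can be
    // in flight, provided each call uses its own workspace and neither the
    // workspace nor the output is touched before its request completes.
    gko::array<char> workspace(rg->get_executor());
    auto req2 = rg->apply_async(b, x, workspace);
    req2.wait();
}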