From 01d9c564334a421e3b3674e0397656138af12e74 Mon Sep 17 00:00:00 2001 From: Marcel Koch Date: Fri, 19 Apr 2024 17:32:04 +0200 Subject: [PATCH] [dist-rg] handle copy to host buffer --- core/distributed/row_gatherer.cpp | 41 ++++++++++++++----- .../ginkgo/core/distributed/row_gatherer.hpp | 3 +- 2 files changed, 32 insertions(+), 12 deletions(-) diff --git a/core/distributed/row_gatherer.cpp b/core/distributed/row_gatherer.cpp index 8dfba55a9c1..9b6292773f3 100644 --- a/core/distributed/row_gatherer.cpp +++ b/core/distributed/row_gatherer.cpp @@ -60,29 +60,48 @@ std::future RowGatherer::apply_async( [&](auto* x_local) { auto exec = rg->get_executor(); + auto use_host_buffer = mpi::requires_host_buffer( + exec, rg->coll_comm_->get_base_communicator()); + auto mpi_exec = + use_host_buffer ? exec->get_master() : exec; + auto b_local = b_global->get_local_vector(); - rg->send_buffer.template init( - b_local->get_executor(), - dim<2>(rg->coll_comm_->get_send_size(), - b_local->get_size()[1])); - rg->send_buffer.template get()->fill(0.0); + rg->send_buffer_.template init( + mpi_exec, dim<2>(rg->coll_comm_->get_send_size(), + b_local->get_size()[1])); + rg->send_buffer_.template get()->fill(0.0); b_local->row_gather( &rg->send_idxs_, - rg->send_buffer.template get()); - - auto recv_ptr = x_local->get_values(); + rg->send_buffer_.template get()); + + if (use_host_buffer) { + rg->recv_buffer_.template init( + mpi_exec, x_local->get_size()); + } + + auto recv_ptr = + use_host_buffer + ? rg->recv_buffer_.template get() + ->get_values() + : x_local->get_values(); auto send_ptr = - rg->send_buffer.template get() + rg->send_buffer_.template get() ->get_values(); - exec->synchronize(); + mpi_exec->synchronize(); mpi::contiguous_type type( b_local->get_size()[1], mpi::type_impl::get_type()); auto g = exec->get_scoped_device_id_guard(); auto req = rg->coll_comm_->i_all_to_all_v( - exec, send_ptr, type.get(), recv_ptr, type.get()); + mpi_exec, send_ptr, type.get(), recv_ptr, + type.get()); req.wait(); + + if (use_host_buffer) { + x_local->copy_from( + rg->recv_buffer_.template get()); + } }, x); }); diff --git a/include/ginkgo/core/distributed/row_gatherer.hpp b/include/ginkgo/core/distributed/row_gatherer.hpp index 3af989d966c..78eba57556e 100644 --- a/include/ginkgo/core/distributed/row_gatherer.hpp +++ b/include/ginkgo/core/distributed/row_gatherer.hpp @@ -158,7 +158,8 @@ class RowGatherer final array send_idxs_; - detail::AnyDenseCache send_buffer; + detail::AnyDenseCache send_buffer_; + detail::AnyDenseCache recv_buffer_; mutable int64 current_id_{0}; mutable std::atomic active_id_{0};