Skip to content

Commit

Permalink
[dist-rg] handle copy to host buffer
Browse files: browse the repository at this point in the history
  • Loading branch information
MarcelKoch committed Aug 16, 2024
1 parent f75b61f commit 01d9c56
Show file tree
Hide file tree
Showing 2 changed files with 32 additions and 12 deletions.
41 changes: 30 additions & 11 deletions core/distributed/row_gatherer.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -60,29 +60,48 @@ std::future<void> RowGatherer<LocalIndexType>::apply_async(
[&](auto* x_local) {
auto exec = rg->get_executor();

auto use_host_buffer = mpi::requires_host_buffer(
exec, rg->coll_comm_->get_base_communicator());
auto mpi_exec =
use_host_buffer ? exec->get_master() : exec;

auto b_local = b_global->get_local_vector();
rg->send_buffer.template init<ValueType>(
b_local->get_executor(),
dim<2>(rg->coll_comm_->get_send_size(),
b_local->get_size()[1]));
rg->send_buffer.template get<ValueType>()->fill(0.0);
rg->send_buffer_.template init<ValueType>(
mpi_exec, dim<2>(rg->coll_comm_->get_send_size(),
b_local->get_size()[1]));
rg->send_buffer_.template get<ValueType>()->fill(0.0);
b_local->row_gather(
&rg->send_idxs_,
rg->send_buffer.template get<ValueType>());

auto recv_ptr = x_local->get_values();
rg->send_buffer_.template get<ValueType>());

if (use_host_buffer) {
rg->recv_buffer_.template init<ValueType>(
mpi_exec, x_local->get_size());
}

auto recv_ptr =
use_host_buffer
? rg->recv_buffer_.template get<ValueType>()
->get_values()
: x_local->get_values();
auto send_ptr =
rg->send_buffer.template get<ValueType>()
rg->send_buffer_.template get<ValueType>()
->get_values();

exec->synchronize();
mpi_exec->synchronize();
mpi::contiguous_type type(
b_local->get_size()[1],
mpi::type_impl<ValueType>::get_type());
auto g = exec->get_scoped_device_id_guard();
auto req = rg->coll_comm_->i_all_to_all_v(
exec, send_ptr, type.get(), recv_ptr, type.get());
mpi_exec, send_ptr, type.get(), recv_ptr,
type.get());
req.wait();

if (use_host_buffer) {
x_local->copy_from(
rg->recv_buffer_.template get<ValueType>());
}
},
x);
});
Expand Down
3 changes: 2 additions & 1 deletion include/ginkgo/core/distributed/row_gatherer.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -158,7 +158,8 @@ class RowGatherer final

array<LocalIndexType> send_idxs_;

detail::AnyDenseCache send_buffer;
detail::AnyDenseCache send_buffer_;
detail::AnyDenseCache recv_buffer_;

mutable int64 current_id_{0};
mutable std::atomic<int64> active_id_{0};
Expand Down

0 comments on commit 01d9c56

Please sign in to comment.