[dist-rg] use mpi request instead of future
MarcelKoch committed Aug 16, 2024
1 parent 01d9c56 commit 49dfd97
Showing 2 changed files with 115 additions and 84 deletions.
162 changes: 89 additions & 73 deletions core/distributed/row_gatherer.cpp
@@ -38,77 +38,85 @@ void RowGatherer<LocalIndexType>::apply_impl(const LinOp* alpha, const LinOp* b,


template <typename LocalIndexType>
std::future<void> RowGatherer<LocalIndexType>::apply_async(
ptr_param<const LinOp> b, ptr_param<LinOp> x) const
mpi::request RowGatherer<LocalIndexType>::apply_async(ptr_param<const LinOp> b,
ptr_param<LinOp> x) const
{
auto op = [b = b.get(), x = x.get(), rg = this->shared_from_this(),
id = current_id_++] {
// ensure that the communications are executed in the order
// the apply_async were called
while (id > rg->active_id_.load()) {
std::this_thread::yield();
}

// dispatch global vector
run<Vector, double, float, std::complex<double>, std::complex<float>>(
b, [&](const auto* b_global) {
using ValueType =
typename std::decay_t<decltype(*b_global)>::value_type;
// dispatch local vector with the same precision as the global
// vector
::gko::precision_dispatch<ValueType>(
[&](auto* x_local) {
auto exec = rg->get_executor();

auto use_host_buffer = mpi::requires_host_buffer(
exec, rg->coll_comm_->get_base_communicator());
auto mpi_exec =
use_host_buffer ? exec->get_master() : exec;

auto b_local = b_global->get_local_vector();
rg->send_buffer_.template init<ValueType>(
mpi_exec, dim<2>(rg->coll_comm_->get_send_size(),
b_local->get_size()[1]));
rg->send_buffer_.template get<ValueType>()->fill(0.0);
b_local->row_gather(
&rg->send_idxs_,
rg->send_buffer_.template get<ValueType>());

if (use_host_buffer) {
rg->recv_buffer_.template init<ValueType>(
mpi_exec, x_local->get_size());
}

auto recv_ptr =
use_host_buffer
? rg->recv_buffer_.template get<ValueType>()
->get_values()
: x_local->get_values();
auto send_ptr =
rg->send_buffer_.template get<ValueType>()
->get_values();

mpi_exec->synchronize();
mpi::contiguous_type type(
b_local->get_size()[1],
mpi::type_impl<ValueType>::get_type());
auto g = exec->get_scoped_device_id_guard();
auto req = rg->coll_comm_->i_all_to_all_v(
mpi_exec, send_ptr, type.get(), recv_ptr,
type.get());
req.wait();

if (use_host_buffer) {
x_local->copy_from(
rg->recv_buffer_.template get<ValueType>());
}
},
x);
});

rg->active_id_++;
};
return std::async(std::launch::async, op);
int is_inactive;
MPI_Status status;
GKO_ASSERT_NO_MPI_ERRORS(
MPI_Request_get_status(req_listener_, &is_inactive, &status));
// This is untestable. Some processes might complete the previous request
// while others don't, so it's impossible to create a predictable behavior
// for a test.
GKO_THROW_IF_INVALID(is_inactive,
"Tried to call RowGatherer::apply_async while there "
"is already an active communication. Please use the "
"overload with a workspace to handle multiple "
"connections.");

auto req = apply_async(b, x, send_workspace_);
req_listener_ = *req.get();
return req;
}
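A minimal usage sketch of the request-based overload introduced here; the names row_gatherer, b, and x are hypothetical and not part of this commit, only the apply_async signature above and mpi::request::wait() are assumed.

// Start the non-blocking row gather; since this commit it returns an
// mpi::request instead of a std::future.
auto req = row_gatherer->apply_async(b, x);
// ... overlap with local work that does not read x ...
req.wait();  // x holds the gathered rows only after the wait returns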


template <typename LocalIndexType>
mpi::request RowGatherer<LocalIndexType>::apply_async(
ptr_param<const LinOp> b, ptr_param<LinOp> x, array<char>& workspace) const
{
mpi::request req;

// dispatch global vector
run<Vector, double, float, std::complex<double>, std::complex<float>>(
b.get(), [&](const auto* b_global) {
using ValueType =
typename std::decay_t<decltype(*b_global)>::value_type;
// dispatch local vector with the same precision as the global
// vector
::gko::precision_dispatch<ValueType>(
[&](auto* x_local) {
auto exec = this->get_executor();

auto use_host_buffer = mpi::requires_host_buffer(
exec, coll_comm_->get_base_communicator());
auto mpi_exec = use_host_buffer ? exec->get_master() : exec;

GKO_THROW_IF_INVALID(
!use_host_buffer || mpi_exec->memory_accessible(
x_local->get_executor()),
"The receive buffer uses device memory, but MPI "
"support of device memory is not available. Please "
"provide a host buffer or enable MPI support for "
"device memory.");

auto b_local = b_global->get_local_vector();

dim<2> send_size(coll_comm_->get_send_size(),
b_local->get_size()[1]);
workspace.set_executor(mpi_exec);
workspace.resize_and_reset(sizeof(ValueType) *
send_size[0] * send_size[1]);
auto send_buffer = matrix::Dense<ValueType>::create(
mpi_exec, send_size,
make_array_view(
mpi_exec, send_size[0] * send_size[1],
reinterpret_cast<ValueType*>(workspace.get_data())),
send_size[1]);
b_local->row_gather(&send_idxs_, send_buffer);

auto recv_ptr = x_local->get_values();
auto send_ptr = send_buffer->get_values();

mpi_exec->synchronize();
mpi::contiguous_type type(
b_local->get_size()[1],
mpi::type_impl<ValueType>::get_type());
req = coll_comm_->i_all_to_all_v(
mpi_exec, send_ptr, type.get(), recv_ptr, type.get());
},
x.get());
});
return req;
}
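A hedged sketch of how the workspace overload can keep several gathers in flight at once; all names (exec, row_gatherer, b_1, x_1, ...) are illustrative and not taken from this commit.

gko::array<char> workspace_1(exec);
gko::array<char> workspace_2(exec);
// Each in-flight communication owns its workspace, so the calls do not
// interfere; the arrays are resized internally to the required send size.
auto req_1 = row_gatherer->apply_async(b_1, x_1, workspace_1);
auto req_2 = row_gatherer->apply_async(b_2, x_2, workspace_2);
req_1.wait();
req_2.wait();  // only after this may x_2 and workspace_2 be touched again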


@@ -130,7 +138,9 @@ RowGatherer<LocalIndexType>::RowGatherer(
exec, dim<2>{imap.get_non_local_size(), imap.get_global_size()}),
DistributedBase(coll_comm->get_base_communicator()),
coll_comm_(std::move(coll_comm)),
send_idxs_(exec)
send_idxs_(exec),
send_workspace_(exec),
req_listener_(MPI_REQUEST_NULL)
{
// check that the coll_comm_ and imap have the same recv size
// the same check for the send size is not possible, since the
@@ -157,15 +167,19 @@ RowGatherer<LocalIndexType>::RowGatherer(std::shared_ptr<const Executor> exec,
: EnableDistributedLinOp<RowGatherer<LocalIndexType>>(exec),
DistributedBase(comm),
coll_comm_(std::make_shared<DefaultCollComm>(comm)),
send_idxs_(exec)
send_idxs_(exec),
send_workspace_(exec),
req_listener_(MPI_REQUEST_NULL)
{}


template <typename LocalIndexType>
RowGatherer<LocalIndexType>::RowGatherer(RowGatherer&& o) noexcept
: EnableDistributedLinOp<RowGatherer<LocalIndexType>>(o.get_executor()),
DistributedBase(o.get_communicator()),
send_idxs_(o.get_executor())
send_idxs_(o.get_executor()),
send_workspace_(o.get_executor()),
req_listener_(MPI_REQUEST_NULL)
{
*this = std::move(o);
}
@@ -195,6 +209,8 @@ RowGatherer<LocalIndexType>& RowGatherer<LocalIndexType>::operator=(
o.coll_comm_,
std::make_shared<DefaultCollComm>(o.get_communicator()));
send_idxs_ = std::move(o.send_idxs_);
send_workspace_ = std::move(o.send_workspace_);
req_listener_ = std::exchange(o.req_listener_, MPI_REQUEST_NULL);
}
return *this;
}
37 changes: 26 additions & 11 deletions include/ginkgo/core/distributed/row_gatherer.hpp
@@ -63,17 +63,34 @@ class RowGatherer final
/**
* Asynchronous version of LinOp::apply.
*
* It is asynchronous only wrt. the calling thread. Multiple calls to this
* function will execute in order, they are not asynchronous with each
* other.
* @warning Only one mpi::request can be active at any given time. This
*          function will throw if another request is already active.
*
* @param b the input distributed::Vector
* @param x the output matrix::Dense with the rows gathered from b
* @return a future for this task. The task is guarantueed to completed
*         after `.wait()` has been called on the future.
* @return a mpi::request for this task. The task is guaranteed to
*         be completed only after `.wait()` has been called on it.
*/
std::future<void> apply_async(ptr_param<const LinOp> b,
ptr_param<LinOp> x) const;
mpi::request apply_async(ptr_param<const LinOp> b,
ptr_param<LinOp> x) const;

/**
* Asynchronous version of LinOp::apply.
*
* @warning Calling this multiple times with the same workspace and without
*          waiting on each previous request will lead to incorrect
*          data transfers.
*
* @param b the input distributed::Vector
* @param x the output matrix::Dense with the rows gathered from b
* @param workspace a workspace to store temporary data for the operation.
*                  This must not be modified before the request is
*                  waited on.
* @return a mpi::request for this task. The task is guaranteed to
*         be completed only after `.wait()` has been called on it.
*/
mpi::request apply_async(ptr_param<const LinOp> b, ptr_param<LinOp> x,
array<char>& workspace) const;
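As an illustration of the warning above, a sketch (hypothetical names, assumed solver-style loop) that reuses a single workspace safely by waiting on each request before issuing the next call:

gko::array<char> workspace(exec);
for (int iter = 0; iter < num_iters; ++iter) {
    auto req = row_gatherer->apply_async(b, x, workspace);
    // local work that touches neither x nor workspace ...
    req.wait();  // required before workspace is reused in the next iteration
}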

/**
* Get the used collective communicator.
@@ -158,11 +175,9 @@ class RowGatherer final

array<LocalIndexType> send_idxs_;

detail::AnyDenseCache send_buffer_;
detail::AnyDenseCache recv_buffer_;
mutable array<char> send_workspace_;

mutable int64 current_id_{0};
mutable std::atomic<int64> active_id_{0};
mutable MPI_Request req_listener_{MPI_REQUEST_NULL};
};


