[dist-rg] adds distributed row-gatherer
MarcelKoch committed Aug 16, 2024
1 parent 1f49b91 commit f75b61f
Showing 4 changed files with 390 additions and 0 deletions.
1 change: 1 addition & 0 deletions core/CMakeLists.txt
@@ -135,6 +135,7 @@ if(GINKGO_BUILD_MPI)
distributed/matrix.cpp
distributed/neighborhood_communicator.cpp
distributed/partition_helpers.cpp
distributed/row_gatherer.cpp
distributed/vector.cpp
distributed/preconditioner/schwarz.cpp)
endif()
215 changes: 215 additions & 0 deletions core/distributed/row_gatherer.cpp
@@ -0,0 +1,215 @@
// SPDX-FileCopyrightText: 2017 - 2024 The Ginkgo authors
//
// SPDX-License-Identifier: BSD-3-Clause

#include "ginkgo/core/distributed/row_gatherer.hpp"

#include <ginkgo/core/base/dense_cache.hpp>
#include <ginkgo/core/base/precision_dispatch.hpp>
#include <ginkgo/core/distributed/dense_communicator.hpp>
#include <ginkgo/core/distributed/neighborhood_communicator.hpp>
#include <ginkgo/core/matrix/dense.hpp>

#include "core/base/dispatch_helper.hpp"

namespace gko {
namespace experimental {
namespace distributed {


#if GINKGO_HAVE_OPENMPI_POST_4_1_X
using DefaultCollComm = mpi::NeighborhoodCommunicator;
#else
using DefaultCollComm = mpi::DenseCommunicator;
#endif


template <typename LocalIndexType>
void RowGatherer<LocalIndexType>::apply_impl(const LinOp* b, LinOp* x) const
{
apply_async(b, x).wait();
}


template <typename LocalIndexType>
void RowGatherer<LocalIndexType>::apply_impl(const LinOp* alpha, const LinOp* b,
const LinOp* beta, LinOp* x) const
GKO_NOT_IMPLEMENTED;


template <typename LocalIndexType>
std::future<void> RowGatherer<LocalIndexType>::apply_async(
ptr_param<const LinOp> b, ptr_param<LinOp> x) const
{
auto op = [b = b.get(), x = x.get(), rg = this->shared_from_this(),
id = current_id_++] {
// ensure that the communications are executed in the order
// in which apply_async was called
while (id > rg->active_id_.load()) {
std::this_thread::yield();
}

// dispatch global vector
run<Vector, double, float, std::complex<double>, std::complex<float>>(
b, [&](const auto* b_global) {
using ValueType =
typename std::decay_t<decltype(*b_global)>::value_type;
// dispatch local vector with the same precision as the global
// vector
::gko::precision_dispatch<ValueType>(
[&](auto* x_local) {
auto exec = rg->get_executor();

auto b_local = b_global->get_local_vector();
rg->send_buffer.template init<ValueType>(
b_local->get_executor(),
dim<2>(rg->coll_comm_->get_send_size(),
b_local->get_size()[1]));
rg->send_buffer.template get<ValueType>()->fill(0.0);
b_local->row_gather(
&rg->send_idxs_,
rg->send_buffer.template get<ValueType>());

auto recv_ptr = x_local->get_values();
auto send_ptr =
rg->send_buffer.template get<ValueType>()
->get_values();

exec->synchronize();
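// one MPI datatype spans a full row (all size[1] values), so the
// variable all-to-all below exchanges whole rows at a time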
mpi::contiguous_type type(
b_local->get_size()[1],
mpi::type_impl<ValueType>::get_type());
auto g = exec->get_scoped_device_id_guard();
auto req = rg->coll_comm_->i_all_to_all_v(
exec, send_ptr, type.get(), recv_ptr, type.get());
req.wait();
},
x);
});

rg->active_id_++;  // release the next queued apply_async
};
return std::async(std::launch::async, op);
}


template <typename LocalIndexType>
std::shared_ptr<const mpi::CollectiveCommunicator>
RowGatherer<LocalIndexType>::get_collective_communicator() const
{
return coll_comm_;
}


template <typename LocalIndexType>
template <typename GlobalIndexType>
RowGatherer<LocalIndexType>::RowGatherer(
std::shared_ptr<const Executor> exec,
std::shared_ptr<const mpi::CollectiveCommunicator> coll_comm,
const index_map<LocalIndexType, GlobalIndexType>& imap)
: EnableDistributedLinOp<RowGatherer<LocalIndexType>>(
exec, dim<2>{imap.get_non_local_size(), imap.get_global_size()}),
DistributedBase(coll_comm->get_base_communicator()),
coll_comm_(std::move(coll_comm)),
send_idxs_(exec)
{
// check that the coll_comm_ and imap have the same recv size;
// the same check for the send size is not possible, since the
// imap doesn't store send indices
GKO_THROW_IF_INVALID(
coll_comm_->get_recv_size() == imap.get_non_local_size(),
"The collective communicator doesn't match the index map.");

auto comm = coll_comm_->get_base_communicator();
auto inverse_comm = coll_comm_->create_inverse();

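// the send indices are not stored in the index map, so they are
// obtained by reversing the communication: each rank sends the local
// indices it will receive back to the owning ranks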
send_idxs_.resize_and_reset(coll_comm_->get_send_size());
inverse_comm
->i_all_to_all_v(exec,
imap.get_remote_local_idxs().get_const_flat_data(),
send_idxs_.get_data())
.wait();
}


template <typename LocalIndexType>
RowGatherer<LocalIndexType>::RowGatherer(std::shared_ptr<const Executor> exec,
mpi::communicator comm)
: EnableDistributedLinOp<RowGatherer<LocalIndexType>>(exec),
DistributedBase(comm),
coll_comm_(std::make_shared<DefaultCollComm>(comm)),
send_idxs_(exec)
{}


template <typename LocalIndexType>
RowGatherer<LocalIndexType>::RowGatherer(RowGatherer&& o) noexcept
: EnableDistributedLinOp<RowGatherer<LocalIndexType>>(o.get_executor()),
DistributedBase(o.get_communicator()),
send_idxs_(o.get_executor())
{
*this = std::move(o);
}


template <typename LocalIndexType>
RowGatherer<LocalIndexType>& RowGatherer<LocalIndexType>::operator=(
const RowGatherer& o)
{
if (this != &o) {
this->set_size(o.get_size());
coll_comm_ = o.coll_comm_;
send_idxs_ = o.send_idxs_;
}
return *this;
}


template <typename LocalIndexType>
RowGatherer<LocalIndexType>& RowGatherer<LocalIndexType>::operator=(
RowGatherer&& o)
{
if (this != &o) {
this->set_size(o.get_size());
o.set_size({});
coll_comm_ = std::exchange(
o.coll_comm_,
std::make_shared<DefaultCollComm>(o.get_communicator()));
send_idxs_ = std::move(o.send_idxs_);
}
return *this;
}


template <typename LocalIndexType>
RowGatherer<LocalIndexType>::RowGatherer(const RowGatherer& o)
: EnableDistributedLinOp<RowGatherer<LocalIndexType>>(o.get_executor()),
DistributedBase(o.get_communicator()),
send_idxs_(o.get_executor())
{
*this = o;
}


#define GKO_DECLARE_ROW_GATHERER(_itype) class RowGatherer<_itype>

GKO_INSTANTIATE_FOR_EACH_INDEX_TYPE(GKO_DECLARE_ROW_GATHERER);

#undef GKO_DECLARE_ROW_GATHERER


#define GKO_DECLARE_ROW_GATHERER_CONSTRUCTOR(_ltype, _gtype) \
RowGatherer<_ltype>::RowGatherer( \
std::shared_ptr<const Executor> exec, \
std::shared_ptr<const mpi::CollectiveCommunicator> coll_comm, \
const index_map<_ltype, _gtype>& imap)

GKO_INSTANTIATE_FOR_EACH_LOCAL_GLOBAL_INDEX_TYPE(
GKO_DECLARE_ROW_GATHERER_CONSTRUCTOR);

#undef GKO_DECLARE_ROW_GATHERER_CONSTRUCTOR


} // namespace distributed
} // namespace experimental
} // namespace gko
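An aside on the ordering mechanism in apply_async above: the ticket id is drawn from current_id_ in the lambda's init-capture, i.e. on the calling thread, and each launched task busy-yields until active_id_ reaches its ticket, so the communications run in call order even though the tasks execute on separate threads. A minimal, self-contained sketch of that pattern (illustrative only; the names mirror the members above, but this is not Ginkgo code):

```c++
#include <atomic>
#include <cstdint>
#include <future>
#include <iostream>
#include <thread>
#include <vector>

int main()
{
    // ticket counter handed out on the calling thread (like current_id_)
    std::int64_t current_id = 0;
    // ticket currently allowed to run (like active_id_)
    std::atomic<std::int64_t> active_id{0};

    std::vector<std::future<void>> futures;
    for (int i = 0; i < 4; ++i) {
        // the init-capture draws the ticket before the task is launched
        futures.push_back(std::async(
            std::launch::async, [id = current_id++, &active_id] {
                // wait until all earlier tickets have finished
                while (id > active_id.load()) {
                    std::this_thread::yield();
                }
                std::cout << "task " << id << " communicates\n";
                // release the next ticket
                active_id++;
            }));
    }
    for (auto& f : futures) {
        f.wait();
    }
}
```

The output is always "task 0" through "task 3" in order, regardless of how the runtime schedules the four threads.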
173 changes: 173 additions & 0 deletions include/ginkgo/core/distributed/row_gatherer.hpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,173 @@
// SPDX-FileCopyrightText: 2017 - 2024 The Ginkgo authors
//
// SPDX-License-Identifier: BSD-3-Clause

#ifndef GKO_PUBLIC_CORE_DISTRIBUTED_ROW_GATHERER_HPP_
#define GKO_PUBLIC_CORE_DISTRIBUTED_ROW_GATHERER_HPP_


#include <ginkgo/config.hpp>


#if GINKGO_BUILD_MPI


#include <future>

#include <ginkgo/core/base/dense_cache.hpp>
#include <ginkgo/core/base/mpi.hpp>
#include <ginkgo/core/distributed/collective_communicator.hpp>
#include <ginkgo/core/distributed/index_map.hpp>
#include <ginkgo/core/distributed/lin_op.hpp>


namespace gko {
namespace experimental {
namespace distributed {


/**
* The distributed::RowGatherer gathers the rows of distributed::Vector that
* are located on other processes.
*
* Example usage:
* ```c++
 * auto coll_comm = std::make_shared<mpi::NeighborhoodCommunicator>(comm, imap);
 * auto rg = distributed::RowGatherer<int32>::create(exec, coll_comm, imap);
*
* auto b = distributed::Vector<double>::create(...);
* auto x = matrix::Dense<double>::create(...);
*
* auto future = rg->apply_async(b, x);
 * // do some computation that doesn't modify b or access x
* future.wait();
* // x now contains the gathered rows of b
* ```
 * Using apply instead of apply_async results in blocking
 * communication.
*
* @note Objects of this class are only available as shared_ptr, since the class
* is derived from std::enable_shared_from_this.
*
* @tparam LocalIndexType the index type for the stored indices
*/
template <typename LocalIndexType = int32>
class RowGatherer final
: public EnableDistributedLinOp<RowGatherer<LocalIndexType>>,
public DistributedBase,
public std::enable_shared_from_this<RowGatherer<LocalIndexType>> {
friend class EnableDistributedPolymorphicObject<RowGatherer, LinOp>;

public:
/**
* Asynchronous version of LinOp::apply.
*
 * It is asynchronous only with respect to the calling thread. Multiple
 * calls to this function execute in order; they are not asynchronous
 * with each other.
*
* @param b the input distributed::Vector
* @param x the output matrix::Dense with the rows gathered from b
 * @return a future for this task. The task is guaranteed to be completed
 * after `.wait()` has been called on the future.
*/
std::future<void> apply_async(ptr_param<const LinOp> b,
ptr_param<LinOp> x) const;

/**
* Get the used collective communicator.
*
* @return the used collective communicator
*/
std::shared_ptr<const mpi::CollectiveCommunicator>
get_collective_communicator() const;

/**
* Creates a distributed::RowGatherer from a given collective communicator
* and index map.
*
* @TODO: using a segmented array instead of the imap would probably be
* more general
*
* @tparam GlobalIndexType the global index type of the index map
*
* @param exec the executor
* @param coll_comm the collective communicator
* @param imap the index map defining which rows to gather
*
* @note The coll_comm and imap have to be compatible. The coll_comm must
* send and recv exactly as many rows as the imap defines.
*
* @return a shared_ptr to the created distributed::RowGatherer
*/
template <typename GlobalIndexType = int64,
typename = std::enable_if_t<sizeof(GlobalIndexType) >=
sizeof(LocalIndexType)>>
static std::shared_ptr<RowGatherer> create(
std::shared_ptr<const Executor> exec,
std::shared_ptr<const mpi::CollectiveCommunicator> coll_comm,
const index_map<LocalIndexType, GlobalIndexType>& imap)
{
return std::shared_ptr<RowGatherer>(
new RowGatherer(std::move(exec), std::move(coll_comm), imap));
}

/**
* Create method for an empty RowGatherer.
*/
static std::shared_ptr<RowGatherer> create(
std::shared_ptr<const Executor> exec, mpi::communicator comm)
{
return std::shared_ptr<RowGatherer>(
new RowGatherer(std::move(exec), std::move(comm)));
}

RowGatherer(const RowGatherer& o);

RowGatherer(RowGatherer&& o) noexcept;

RowGatherer& operator=(const RowGatherer& o);

RowGatherer& operator=(RowGatherer&& o);

protected:
void apply_impl(const LinOp* b, LinOp* x) const override;

void apply_impl(const LinOp* alpha, const LinOp* b, const LinOp* beta,
LinOp* x) const override;

private:
/**
* @copydoc RowGatherer::create(std::shared_ptr<const
* Executor>, std::shared_ptr<const mpi::collective_communicator>,
* const index_map<LocalIndexType, GlobalIndexType>&)
*/
template <typename GlobalIndexType>
RowGatherer(std::shared_ptr<const Executor> exec,
std::shared_ptr<const mpi::CollectiveCommunicator> coll_comm,
const index_map<LocalIndexType, GlobalIndexType>& imap);

/**
* @copydoc RowGatherer::create(std::shared_ptr<const
* Executor>, mpi::communicator)
*/
RowGatherer(std::shared_ptr<const Executor> exec, mpi::communicator comm);

std::shared_ptr<const mpi::CollectiveCommunicator> coll_comm_;

array<LocalIndexType> send_idxs_;

detail::AnyDenseCache send_buffer;

mutable int64 current_id_{0};
mutable std::atomic<int64> active_id_{0};
};


} // namespace distributed
} // namespace experimental
} // namespace gko

#endif  // GINKGO_BUILD_MPI
#endif // GKO_PUBLIC_CORE_DISTRIBUTED_ROW_GATHERER_HPP_
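To close out the header, a hedged sketch of the intended calling pattern, expanded from the class documentation above: it shows the overlap of communication and computation, with the blocking apply as a fallback. The executor, communicator, index map, and vector setup are elided with `...` exactly as in the doc example, and the helper do_unrelated_local_work is a hypothetical stand-in; treat this as an illustration, not a complete program.

```c++
// Illustrative calling pattern only; setup elided as in the class docs.
auto coll_comm = std::make_shared<mpi::NeighborhoodCommunicator>(comm, imap);
auto rg = distributed::RowGatherer<int32>::create(exec, coll_comm, imap);

auto b = distributed::Vector<double>::create(...);  // distributed input
auto x = matrix::Dense<double>::create(...);        // gathered local rows

// asynchronous: communication proceeds while the caller keeps working
auto future = rg->apply_async(b, x);
do_unrelated_local_work();  // hypothetical; must not modify b or access x
future.wait();              // x now holds the gathered rows of b

// blocking alternative: apply is implemented as apply_async(b, x).wait()
rg->apply(b, x);
```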