[dist-rg] adds distributed row-gatherer
1 parent: 1f49b91
Commit: f75b61f
Showing 4 changed files with 390 additions and 0 deletions.
@@ -0,0 +1,215 @@
// SPDX-FileCopyrightText: 2017 - 2024 The Ginkgo authors
//
// SPDX-License-Identifier: BSD-3-Clause

#include "ginkgo/core/distributed/row_gatherer.hpp"

#include <ginkgo/core/base/dense_cache.hpp>
#include <ginkgo/core/base/precision_dispatch.hpp>
#include <ginkgo/core/distributed/dense_communicator.hpp>
#include <ginkgo/core/distributed/neighborhood_communicator.hpp>
#include <ginkgo/core/matrix/dense.hpp>

#include "core/base/dispatch_helper.hpp"


namespace gko {
namespace experimental {
namespace distributed {


#if GINKGO_HAVE_OPENMPI_POST_4_1_X
using DefaultCollComm = mpi::NeighborhoodCommunicator;
#else
using DefaultCollComm = mpi::DenseCommunicator;
#endif


template <typename LocalIndexType>
void RowGatherer<LocalIndexType>::apply_impl(const LinOp* b, LinOp* x) const
{
    apply_async(b, x).wait();
}


template <typename LocalIndexType>
void RowGatherer<LocalIndexType>::apply_impl(const LinOp* alpha, const LinOp* b,
                                             const LinOp* beta, LinOp* x) const
    GKO_NOT_IMPLEMENTED;


template <typename LocalIndexType>
std::future<void> RowGatherer<LocalIndexType>::apply_async(
    ptr_param<const LinOp> b, ptr_param<LinOp> x) const
{
    auto op = [b = b.get(), x = x.get(), rg = this->shared_from_this(),
               id = current_id_++] {
        // ensure that the communications are executed in the order
        // in which the apply_async calls were made
        while (id > rg->active_id_.load()) {
            std::this_thread::yield();
        }

        // dispatch global vector
        run<Vector, double, float, std::complex<double>, std::complex<float>>(
            b, [&](const auto* b_global) {
                using ValueType =
                    typename std::decay_t<decltype(*b_global)>::value_type;
                // dispatch local vector with the same precision as the global
                // vector
                ::gko::precision_dispatch<ValueType>(
                    [&](auto* x_local) {
                        auto exec = rg->get_executor();

                        auto b_local = b_global->get_local_vector();
                        rg->send_buffer.template init<ValueType>(
                            b_local->get_executor(),
                            dim<2>(rg->coll_comm_->get_send_size(),
                                   b_local->get_size()[1]));
                        rg->send_buffer.template get<ValueType>()->fill(0.0);
                        b_local->row_gather(
                            &rg->send_idxs_,
                            rg->send_buffer.template get<ValueType>());

                        auto recv_ptr = x_local->get_values();
                        auto send_ptr =
                            rg->send_buffer.template get<ValueType>()
                                ->get_values();

                        exec->synchronize();
                        mpi::contiguous_type type(
                            b_local->get_size()[1],
                            mpi::type_impl<ValueType>::get_type());
                        auto g = exec->get_scoped_device_id_guard();
                        auto req = rg->coll_comm_->i_all_to_all_v(
                            exec, send_ptr, type.get(), recv_ptr, type.get());
                        req.wait();
                    },
                    x);
            });

        rg->active_id_++;
    };
    return std::async(std::launch::async, op);
}


template <typename LocalIndexType>
std::shared_ptr<const mpi::CollectiveCommunicator>
RowGatherer<LocalIndexType>::get_collective_communicator() const
{
    return coll_comm_;
}


template <typename LocalIndexType>
template <typename GlobalIndexType>
RowGatherer<LocalIndexType>::RowGatherer(
    std::shared_ptr<const Executor> exec,
    std::shared_ptr<const mpi::CollectiveCommunicator> coll_comm,
    const index_map<LocalIndexType, GlobalIndexType>& imap)
    : EnableDistributedLinOp<RowGatherer<LocalIndexType>>(
          exec, dim<2>{imap.get_non_local_size(), imap.get_global_size()}),
      DistributedBase(coll_comm->get_base_communicator()),
      coll_comm_(std::move(coll_comm)),
      send_idxs_(exec)
{
    // check that the coll_comm_ and imap have the same recv size
    // the same check for the send size is not possible, since the
    // imap doesn't store send indices
    GKO_THROW_IF_INVALID(
        coll_comm_->get_recv_size() == imap.get_non_local_size(),
        "The collective communicator doesn't match the index map.");

    auto comm = coll_comm_->get_base_communicator();
    auto inverse_comm = coll_comm_->create_inverse();

    send_idxs_.resize_and_reset(coll_comm_->get_send_size());
    inverse_comm
        ->i_all_to_all_v(exec,
                         imap.get_remote_local_idxs().get_const_flat_data(),
                         send_idxs_.get_data())
        .wait();
}


template <typename LocalIndexType>
RowGatherer<LocalIndexType>::RowGatherer(std::shared_ptr<const Executor> exec,
                                         mpi::communicator comm)
    : EnableDistributedLinOp<RowGatherer<LocalIndexType>>(exec),
      DistributedBase(comm),
      coll_comm_(std::make_shared<DefaultCollComm>(comm)),
      send_idxs_(exec)
{}


template <typename LocalIndexType>
RowGatherer<LocalIndexType>::RowGatherer(RowGatherer&& o) noexcept
    : EnableDistributedLinOp<RowGatherer<LocalIndexType>>(o.get_executor()),
      DistributedBase(o.get_communicator()),
      send_idxs_(o.get_executor())
{
    *this = std::move(o);
}


template <typename LocalIndexType>
RowGatherer<LocalIndexType>& RowGatherer<LocalIndexType>::operator=(
    const RowGatherer& o)
{
    if (this != &o) {
        this->set_size(o.get_size());
        coll_comm_ = o.coll_comm_;
        send_idxs_ = o.send_idxs_;
    }
    return *this;
}


template <typename LocalIndexType>
RowGatherer<LocalIndexType>& RowGatherer<LocalIndexType>::operator=(
    RowGatherer&& o)
{
    if (this != &o) {
        this->set_size(o.get_size());
        o.set_size({});
        coll_comm_ = std::exchange(
            o.coll_comm_,
            std::make_shared<DefaultCollComm>(o.get_communicator()));
        send_idxs_ = std::move(o.send_idxs_);
    }
    return *this;
}


template <typename LocalIndexType>
RowGatherer<LocalIndexType>::RowGatherer(const RowGatherer& o)
    : EnableDistributedLinOp<RowGatherer<LocalIndexType>>(o.get_executor()),
      DistributedBase(o.get_communicator()),
      send_idxs_(o.get_executor())
{
    *this = o;
}


#define GKO_DECLARE_ROW_GATHERER(_itype) class RowGatherer<_itype>

GKO_INSTANTIATE_FOR_EACH_INDEX_TYPE(GKO_DECLARE_ROW_GATHERER);

#undef GKO_DECLARE_ROW_GATHERER


#define GKO_DECLARE_ROW_GATHERER_CONSTRUCTOR(_ltype, _gtype)          \
    RowGatherer<_ltype>::RowGatherer(                                 \
        std::shared_ptr<const Executor> exec,                         \
        std::shared_ptr<const mpi::CollectiveCommunicator> coll_comm, \
        const index_map<_ltype, _gtype>& imap)

GKO_INSTANTIATE_FOR_EACH_LOCAL_GLOBAL_INDEX_TYPE(
    GKO_DECLARE_ROW_GATHERER_CONSTRUCTOR);

#undef GKO_DECLARE_ROW_GATHERER_CONSTRUCTOR


}  // namespace distributed
}  // namespace experimental
}  // namespace gko
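The ordering guarantee in `apply_async` comes from the ticket scheme visible above: each call captures `id = current_id_++`, the launched task yields until the shared atomic `active_id_` reaches its ticket, and increments it once the communication is done. Below is a minimal, self-contained sketch of just that pattern; the names (`OrderedQueue`, `run`) are hypothetical and not part of this commit.

```c++
#include <atomic>
#include <cstdio>
#include <future>
#include <thread>
#include <vector>

// Sketch of the ticket-based ordering used by RowGatherer::apply_async:
// each submission takes a ticket, and the launched task yields until the
// shared counter reaches that ticket, so tasks run in submission order.
struct OrderedQueue {
    long next_ticket{0};          // plays the role of current_id_
    std::atomic<long> active{0};  // plays the role of active_id_

    template <typename Fn>
    std::future<void> run(Fn fn)
    {
        auto op = [this, fn, id = next_ticket++] {
            while (id > active.load()) {
                std::this_thread::yield();  // wait for earlier submissions
            }
            fn();      // the communication would happen here, in call order
            active++;  // release the next queued task
        };
        return std::async(std::launch::async, op);
    }
};

int main()
{
    OrderedQueue queue;
    std::vector<std::future<void>> futures;
    for (int i = 0; i < 4; ++i) {
        futures.push_back(queue.run([i] { std::printf("task %d\n", i); }));
    }
    for (auto& f : futures) {
        f.wait();  // prints task 0..3 in order, regardless of scheduling
    }
}
```

Like the commit's `current_id_`, `next_ticket` is not atomic, so the scheme assumes all submissions come from a single calling thread.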
@@ -0,0 +1,173 @@
// SPDX-FileCopyrightText: 2017 - 2024 The Ginkgo authors
//
// SPDX-License-Identifier: BSD-3-Clause

#ifndef GKO_PUBLIC_CORE_DISTRIBUTED_ROW_GATHERER_HPP_
#define GKO_PUBLIC_CORE_DISTRIBUTED_ROW_GATHERER_HPP_


#include <ginkgo/config.hpp>


#if GINKGO_BUILD_MPI


#include <future>

#include <ginkgo/core/base/dense_cache.hpp>
#include <ginkgo/core/base/mpi.hpp>
#include <ginkgo/core/distributed/collective_communicator.hpp>
#include <ginkgo/core/distributed/index_map.hpp>
#include <ginkgo/core/distributed/lin_op.hpp>


namespace gko {
namespace experimental {
namespace distributed {

/**
 * The distributed::RowGatherer gathers the rows of a distributed::Vector that
 * are located on other processes.
 *
 * Example usage:
 * ```c++
 * auto coll_comm =
 *     std::make_shared<mpi::NeighborhoodCommunicator>(comm, imap);
 * auto rg = distributed::RowGatherer<int32>::create(exec, coll_comm, imap);
 *
 * auto b = distributed::Vector<double>::create(...);
 * auto x = matrix::Dense<double>::create(...);
 *
 * auto future = rg->apply_async(b, x);
 * // do some computation that doesn't modify b, or access x
 * future.wait();
 * // x now contains the gathered rows of b
 * ```
 * Using apply instead of apply_async will lead to blocking communication.
 *
 * @note Objects of this class are only available as shared_ptr, since the
 *       class is derived from std::enable_shared_from_this.
 *
 * @tparam LocalIndexType  the index type for the stored indices
 */
template <typename LocalIndexType = int32>
class RowGatherer final
    : public EnableDistributedLinOp<RowGatherer<LocalIndexType>>,
      public DistributedBase,
      public std::enable_shared_from_this<RowGatherer<LocalIndexType>> {
    friend class EnableDistributedPolymorphicObject<RowGatherer, LinOp>;

public:
    /**
     * Asynchronous version of LinOp::apply.
     *
     * It is asynchronous only with respect to the calling thread. Multiple
     * calls to this function will execute in order; they are not
     * asynchronous with each other.
     *
     * @param b  the input distributed::Vector
     * @param x  the output matrix::Dense with the rows gathered from b
     *
     * @return a future for this task. The task is guaranteed to be completed
     *         after `.wait()` has been called on the future.
     */
    std::future<void> apply_async(ptr_param<const LinOp> b,
                                  ptr_param<LinOp> x) const;

    /**
     * Get the used collective communicator.
     *
     * @return the used collective communicator
     */
    std::shared_ptr<const mpi::CollectiveCommunicator>
    get_collective_communicator() const;

    /**
     * Creates a distributed::RowGatherer from a given collective communicator
     * and index map.
     *
     * @TODO: using a segmented array instead of the imap would probably be
     *        more general
     *
     * @tparam GlobalIndexType  the global index type of the index map
     *
     * @param exec  the executor
     * @param coll_comm  the collective communicator
     * @param imap  the index map defining which rows to gather
     *
     * @note The coll_comm and imap have to be compatible. The coll_comm must
     *       send and recv exactly as many rows as the imap defines.
     *
     * @return a shared_ptr to the created distributed::RowGatherer
     */
    template <typename GlobalIndexType = int64,
              typename = std::enable_if_t<sizeof(GlobalIndexType) >=
                                          sizeof(LocalIndexType)>>
    static std::shared_ptr<RowGatherer> create(
        std::shared_ptr<const Executor> exec,
        std::shared_ptr<const mpi::CollectiveCommunicator> coll_comm,
        const index_map<LocalIndexType, GlobalIndexType>& imap)
    {
        return std::shared_ptr<RowGatherer>(
            new RowGatherer(std::move(exec), std::move(coll_comm), imap));
    }

    /**
     * Create method for an empty RowGatherer.
     */
    static std::shared_ptr<RowGatherer> create(
        std::shared_ptr<const Executor> exec, mpi::communicator comm)
    {
        return std::shared_ptr<RowGatherer>(
            new RowGatherer(std::move(exec), std::move(comm)));
    }

    RowGatherer(const RowGatherer& o);

    RowGatherer(RowGatherer&& o) noexcept;

    RowGatherer& operator=(const RowGatherer& o);

    RowGatherer& operator=(RowGatherer&& o);

protected:
    void apply_impl(const LinOp* b, LinOp* x) const override;

    void apply_impl(const LinOp* alpha, const LinOp* b, const LinOp* beta,
                    LinOp* x) const override;

private:
    /**
     * @copydoc RowGatherer::create(std::shared_ptr<const Executor>,
     *          std::shared_ptr<const mpi::CollectiveCommunicator>,
     *          const index_map<LocalIndexType, GlobalIndexType>&)
     */
    template <typename GlobalIndexType>
    RowGatherer(std::shared_ptr<const Executor> exec,
                std::shared_ptr<const mpi::CollectiveCommunicator> coll_comm,
                const index_map<LocalIndexType, GlobalIndexType>& imap);

    /**
     * @copydoc RowGatherer::create(std::shared_ptr<const Executor>,
     *          mpi::communicator)
     */
    RowGatherer(std::shared_ptr<const Executor> exec, mpi::communicator comm);

    std::shared_ptr<const mpi::CollectiveCommunicator> coll_comm_;

    array<LocalIndexType> send_idxs_;

    detail::AnyDenseCache send_buffer;

    mutable int64 current_id_{0};
    mutable std::atomic<int64> active_id_{0};
};


}  // namespace distributed
}  // namespace experimental
}  // namespace gko


#endif  // GINKGO_BUILD_MPI
#endif  // GKO_PUBLIC_CORE_DISTRIBUTED_ROW_GATHERER_HPP_
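For context, here is how the new class is meant to be used end to end, following the example in the class documentation: build a collective communicator from an index map, create the RowGatherer, and overlap the gather with local work. This is a sketch only — `exec`, `comm`, and `imap` (an `index_map<int32, int64>` describing the needed non-local rows) are assumed to exist, and the `gko`/`gko::experimental` namespaces are taken to be in scope as in the documentation example.

```c++
// Assumed setup (not part of this commit): exec is an Executor,
// comm an mpi::communicator, imap an index_map<int32, int64>.
auto coll_comm =
    std::make_shared<mpi::NeighborhoodCommunicator>(comm, imap);
auto rg = distributed::RowGatherer<int32>::create(exec, coll_comm, imap);

// b holds the distributed data; x receives the gathered non-local rows,
// so it needs imap.get_non_local_size() rows and as many columns as b.
auto b = distributed::Vector<double>::create(/* ... */);
auto x = matrix::Dense<double>::create(
    exec, dim<2>{imap.get_non_local_size(), b->get_size()[1]});

auto future = rg->apply_async(b, x);
// ... local work that neither modifies b nor reads x ...
future.wait();
// x now contains the rows of b referenced by imap's remote indices
```

Calling `rg->apply(b, x)` instead performs the same gather but blocks until the communication has completed.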