Skip to content

Commit

Permalink
[pgm] use row-gatherer from matrix
Browse files Browse the repository at this point in the history
  • Loading branch information
MarcelKoch committed Aug 16, 2024
1 parent 03d350d commit 67e8321
Show file tree
Hide file tree
Showing 3 changed files with 23 additions and 12 deletions.
7 changes: 7 additions & 0 deletions core/distributed/row_gatherer.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -161,6 +161,13 @@ RowGatherer<LocalIndexType>::RowGatherer(
}


template <typename LocalIndexType>
const LocalIndexType* RowGatherer<LocalIndexType>::get_const_row_idxs() const
{
return send_idxs_.get_const_data();
}


template <typename LocalIndexType>
RowGatherer<LocalIndexType>::RowGatherer(std::shared_ptr<const Executor> exec,
mpi::communicator comm)
Expand Down
21 changes: 9 additions & 12 deletions core/multigrid/pgm.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -264,18 +264,15 @@ array<GlobalIndexType> Pgm<ValueType, IndexType>::communicate_non_local_agg(
{
auto exec = gko::as<LinOp>(matrix)->get_executor();
const auto comm = matrix->get_communicator();
auto send_sizes = matrix->send_sizes_;
auto recv_sizes = matrix->recv_sizes_;
auto send_offsets = matrix->send_offsets_;
auto recv_offsets = matrix->recv_offsets_;
auto gather_idxs = matrix->gather_idxs_;
auto total_send_size = send_offsets.back();
auto total_recv_size = recv_offsets.back();
auto coll_comm = matrix->row_gatherer_->get_collective_communicator();
auto total_send_size = coll_comm->get_send_size();
auto total_recv_size = coll_comm->get_recv_size();
auto row_gatherer = matrix->row_gatherer_;

array<IndexType> send_agg(exec, total_send_size);
exec->run(pgm::make_gather_index(
send_agg.get_size(), local_agg.get_const_data(),
gather_idxs.get_const_data(), send_agg.get_data()));
row_gatherer->get_const_row_idxs(), send_agg.get_data()));

// temporary index map that contains no remote connections to map
// local indices to global
Expand All @@ -296,16 +293,16 @@ array<GlobalIndexType> Pgm<ValueType, IndexType>::communicate_non_local_agg(
seng_global_agg.get_data(),
host_send_buffer.get_data());
}
auto type = experimental::mpi::type_impl<GlobalIndexType>::get_type();

const auto send_ptr = use_host_buffer ? host_send_buffer.get_const_data()
: seng_global_agg.get_const_data();
auto recv_ptr = use_host_buffer ? host_recv_buffer.get_data()
: non_local_agg.get_data();
exec->synchronize();
comm.all_to_all_v(use_host_buffer ? exec->get_master() : exec, send_ptr,
send_sizes.data(), send_offsets.data(), type, recv_ptr,
recv_sizes.data(), recv_offsets.data(), type);
coll_comm
->i_all_to_all_v(use_host_buffer ? exec->get_master() : exec, send_ptr,
recv_ptr)
.wait();
if (use_host_buffer) {
exec->copy_from(exec->get_master(), total_recv_size, recv_ptr,
non_local_agg.get_data());
Expand Down
7 changes: 7 additions & 0 deletions include/ginkgo/core/distributed/row_gatherer.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -100,6 +100,13 @@ class RowGatherer final
std::shared_ptr<const mpi::CollectiveCommunicator>
get_collective_communicator() const;

/**
* Read access to the (local) rows indices
*
* @return the (local) row indices that are gathered
*/
const LocalIndexType* get_const_row_idxs() const;

/**
* Creates a distributed::RowGatherer from a given collective communicator
* and index map.
Expand Down

0 comments on commit 67e8321

Please sign in to comment.