Skip to content

Commit

Permalink
wip
Browse files Browse the repository at this point in the history
  • Loading branch information
MarcelKoch committed Apr 26, 2022
1 parent 7db9e12 commit 74d8f03
Show file tree
Hide file tree
Showing 4 changed files with 35 additions and 36 deletions.
12 changes: 6 additions & 6 deletions common/cuda_hip/distributed/vector_kernels.hpp.inc
Original file line number Diff line number Diff line change
Expand Up @@ -58,7 +58,7 @@ void build_local(
// array
// the flat_idx_it is used by the scatter_if as an index map for the values
auto map_to_local_row =
[range_bounds, range_starting_indices] GKO_THRUST_LAMBDA(
[range_bounds, range_starting_indices] __host__ __device__(
const thrust::tuple<GlobalIndexType, size_type>& idx_range_id)
-> LocalIndexType {
const auto idx = thrust::get<0>(idx_range_id);
Expand All @@ -73,7 +73,7 @@ void build_local(

auto stride = local_mtx->get_stride();
auto map_to_flat_idx =
[stride] GKO_THRUST_LAMBDA(
[stride] __host__ __device__(
const thrust::tuple<LocalIndexType, GlobalIndexType>& row_col)
-> size_type {
return thrust::get<0>(row_col) * stride + thrust::get<1>(row_col);
Expand All @@ -83,10 +83,10 @@ void build_local(
thrust::make_tuple(local_row_it, input.get_const_col_idxs())),
map_to_flat_idx);

auto is_local_row = [part_ids,
local_part] GKO_THRUST_LAMBDA(const size_type rid) {
return part_ids[rid] == local_part;
};
auto is_local_row =
[part_ids, local_part] __host__ __device__(const size_type rid) {
return part_ids[rid] == local_part;
};
thrust::scatter_if(thrust::device, input.get_const_values(),
input.get_const_values() + input.get_num_elems(),
flat_idx_it, range_id.get_data(),
Expand Down
6 changes: 0 additions & 6 deletions cuda/distributed/vector_kernels.cu
Original file line number Diff line number Diff line change
Expand Up @@ -50,15 +50,9 @@ namespace cuda {
namespace distributed_vector {


#define GKO_THRUST_LAMBDA __device__


#include "common/cuda_hip/distributed/vector_kernels.hpp.inc"


#undef GKO_THRUST_LAMBDA


} // namespace distributed_vector
} // namespace cuda
} // namespace kernels
Expand Down
47 changes: 29 additions & 18 deletions dpcpp/distributed/vector_kernels.dp.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -30,11 +30,16 @@ THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
******************************<GINKGO LICENSE>*******************************/

#include "core/distributed/vector_kernels.hpp"


// force-top: on
// oneDPL headers need to be included first to avoid issues with the TBB-based implementation in libstdc++
#include <oneapi/dpl/algorithm>
#include <oneapi/dpl/execution>
#include <oneapi/dpl/iterator>
// force-top: off


#include "core/distributed/vector_kernels.hpp"


#include <ginkgo/core/base/exception_helpers.hpp>
Expand Down Expand Up @@ -81,25 +86,31 @@ void build_local(
return static_cast<LocalIndexType>(idx - range_bounds[rid]) +
range_starting_indices[rid];
};
auto local_row_it = oneapi::dpl::make_transform_iterator(
oneapi::dpl::make_zip_iterator(input.get_const_row_idxs(),
range_id.get_data()),
map_to_local_row);

auto flat_idx_it = oneapi::dpl::make_permutation_iterator(
local_mtx->get_values(),
[local_row_it, cols = input.get_const_col_idxs(),
stride = local_mtx->get_stride()](const auto i) {
return local_row_it[i] * stride + cols[i];
Array<size_type> flat_idx_map{exec, input.get_num_elems()};
auto zip_it = oneapi::dpl::make_zip_iterator(input.get_const_row_idxs(),
input.get_const_col_idxs(),
range_id.get_const_data());
oneapi::dpl::transform(
policy, zip_it, zip_it + input.get_num_elems(), flat_idx_map.get_data(),
[cols = input.get_const_col_idxs(), stride = local_mtx->get_stride(),
map_to_local_row](const auto t) {
auto [row, col, rid] = t;
auto local_row = map_to_local_row(std::make_tuple(row, rid));
return local_row * stride + col;
});
auto flat_idx_it = oneapi::dpl::make_permutation_iterator(
local_mtx->get_values(), flat_idx_map.get_data());

auto is_local_row = [range_id = range_id.get_data(), part_ids,
local_part](const auto i) {
return part_ids[range_id[i]] == local_part;
auto is_local_row = [part_ids, local_part](const auto t) {
return part_ids[std::get<1>(t)] == local_part;
};
oneapi::dpl::copy_if(policy, input.get_const_values(),
input.get_const_values() + input.get_num_elems(),
flat_idx_it, is_local_row);
auto value_rid_it = oneapi::dpl::make_zip_iterator(
input.get_const_values(), range_id.get_const_data());
oneapi::dpl::copy_if(policy, value_rid_it,
value_rid_it + input.get_num_elems(),
oneapi::dpl::make_zip_iterator(
flat_idx_it, oneapi::dpl::discard_iterator()),
is_local_row);
}

GKO_INSTANTIATE_FOR_EACH_VALUE_AND_LOCAL_GLOBAL_INDEX_TYPE(
Expand Down
6 changes: 0 additions & 6 deletions hip/distributed/vector_kernels.hip.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -53,15 +53,9 @@ namespace hip {
namespace distributed_vector {


#define GKO_THRUST_LAMBDA __device__ __host__


#include "common/cuda_hip/distributed/vector_kernels.hpp.inc"


#undef GKO_THRUST_LAMBDA


} // namespace distributed_vector
} // namespace hip
} // namespace kernels
Expand Down

0 comments on commit 74d8f03

Please sign in to comment.