From 74d8f03aee85c222d7a7b93fcaf2636bac80d1a7 Mon Sep 17 00:00:00 2001 From: Marcel Koch Date: Tue, 26 Apr 2022 18:25:05 +0200 Subject: [PATCH] wip --- .../distributed/vector_kernels.hpp.inc | 12 ++--- cuda/distributed/vector_kernels.cu | 6 --- dpcpp/distributed/vector_kernels.dp.cpp | 47 ++++++++++++------- hip/distributed/vector_kernels.hip.cpp | 6 --- 4 files changed, 35 insertions(+), 36 deletions(-) diff --git a/common/cuda_hip/distributed/vector_kernels.hpp.inc b/common/cuda_hip/distributed/vector_kernels.hpp.inc index f21bbb2d706..504f7153062 100644 --- a/common/cuda_hip/distributed/vector_kernels.hpp.inc +++ b/common/cuda_hip/distributed/vector_kernels.hpp.inc @@ -58,7 +58,7 @@ void build_local( // array // the flat_idx_it is used by the scatter_if as an index map for the values auto map_to_local_row = - [range_bounds, range_starting_indices] GKO_THRUST_LAMBDA( + [range_bounds, range_starting_indices] __host__ __device__( const thrust::tuple& idx_range_id) -> LocalIndexType { const auto idx = thrust::get<0>(idx_range_id); @@ -73,7 +73,7 @@ void build_local( auto stride = local_mtx->get_stride(); auto map_to_flat_idx = - [stride] GKO_THRUST_LAMBDA( + [stride] __host__ __device__( const thrust::tuple& row_col) -> size_type { return thrust::get<0>(row_col) * stride + thrust::get<1>(row_col); @@ -83,10 +83,10 @@ void build_local( thrust::make_tuple(local_row_it, input.get_const_col_idxs())), map_to_flat_idx); - auto is_local_row = [part_ids, - local_part] GKO_THRUST_LAMBDA(const size_type rid) { - return part_ids[rid] == local_part; - }; + auto is_local_row = + [part_ids, local_part] __host__ __device__(const size_type rid) { + return part_ids[rid] == local_part; + }; thrust::scatter_if(thrust::device, input.get_const_values(), input.get_const_values() + input.get_num_elems(), flat_idx_it, range_id.get_data(), diff --git a/cuda/distributed/vector_kernels.cu b/cuda/distributed/vector_kernels.cu index 168bc4eabdc..def3fc8ec87 100644 --- a/cuda/distributed/vector_kernels.cu +++ b/cuda/distributed/vector_kernels.cu @@ -50,15 +50,9 @@ namespace cuda { namespace distributed_vector { -#define GKO_THRUST_LAMBDA __device__ - - #include "common/cuda_hip/distributed/vector_kernels.hpp.inc" -#undef GKO_THRUST_LAMBDA - - } // namespace distributed_vector } // namespace cuda } // namespace kernels diff --git a/dpcpp/distributed/vector_kernels.dp.cpp b/dpcpp/distributed/vector_kernels.dp.cpp index 45ffed4c1ed..6534418b454 100644 --- a/dpcpp/distributed/vector_kernels.dp.cpp +++ b/dpcpp/distributed/vector_kernels.dp.cpp @@ -30,11 +30,16 @@ THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. *************************************************************/ -#include "core/distributed/vector_kernels.hpp" - +// force-top: on +// oneDPL needs to be first to avoid issues with libstdc++ TBB impl #include +#include #include +// force-top: off + + +#include "core/distributed/vector_kernels.hpp" #include @@ -81,25 +86,31 @@ void build_local( return static_cast(idx - range_bounds[rid]) + range_starting_indices[rid]; }; - auto local_row_it = oneapi::dpl::make_transform_iterator( - oneapi::dpl::make_zip_iterator(input.get_const_row_idxs(), - range_id.get_data()), - map_to_local_row); - - auto flat_idx_it = oneapi::dpl::make_permutation_iterator( - local_mtx->get_values(), - [local_row_it, cols = input.get_const_col_idxs(), - stride = local_mtx->get_stride()](const auto i) { - return local_row_it[i] * stride + cols[i]; + Array flat_idx_map{exec, input.get_num_elems()}; + auto zip_it = oneapi::dpl::make_zip_iterator(input.get_const_row_idxs(), + input.get_const_col_idxs(), + range_id.get_const_data()); + oneapi::dpl::transform( + policy, zip_it, zip_it + input.get_num_elems(), flat_idx_map.get_data(), + [cols = input.get_const_col_idxs(), stride = local_mtx->get_stride(), + map_to_local_row](const auto t) { + auto [row, col, rid] = t; + auto local_row = map_to_local_row(std::make_tuple(row, rid)); + return local_row * stride + col; }); + auto flat_idx_it = oneapi::dpl::make_permutation_iterator( + local_mtx->get_values(), flat_idx_map.get_data()); - auto is_local_row = [range_id = range_id.get_data(), part_ids, - local_part](const auto i) { - return part_ids[range_id[i]] == local_part; + auto is_local_row = [part_ids, local_part](const auto t) { + return part_ids[std::get<1>(t)] == local_part; }; - oneapi::dpl::copy_if(policy, input.get_const_values(), - input.get_const_values() + input.get_num_elems(), - flat_idx_it, is_local_row); + auto value_rid_it = oneapi::dpl::make_zip_iterator( + input.get_const_values(), range_id.get_const_data()); + oneapi::dpl::copy_if(policy, value_rid_it, + value_rid_it + input.get_num_elems(), + oneapi::dpl::make_zip_iterator( + flat_idx_it, oneapi::dpl::discard_iterator()), + is_local_row); } GKO_INSTANTIATE_FOR_EACH_VALUE_AND_LOCAL_GLOBAL_INDEX_TYPE( diff --git a/hip/distributed/vector_kernels.hip.cpp b/hip/distributed/vector_kernels.hip.cpp index bbc7ee1eb1b..6cbfa1224e9 100644 --- a/hip/distributed/vector_kernels.hip.cpp +++ b/hip/distributed/vector_kernels.hip.cpp @@ -53,15 +53,9 @@ namespace hip { namespace distributed_vector { -#define GKO_THRUST_LAMBDA __device__ __host__ - - #include "common/cuda_hip/distributed/vector_kernels.hpp.inc" -#undef GKO_THRUST_LAMBDA - - } // namespace distributed_vector } // namespace hip } // namespace kernels