diff --git a/common/cuda_hip/distributed/index_map_kernels.cpp b/common/cuda_hip/distributed/index_map_kernels.cpp index e27c5221013..47a795d1e7b 100644 --- a/common/cuda_hip/distributed/index_map_kernels.cpp +++ b/common/cuda_hip/distributed/index_map_kernels.cpp @@ -296,6 +296,90 @@ GKO_INSTANTIATE_FOR_EACH_LOCAL_GLOBAL_INDEX_TYPE( GKO_DECLARE_INDEX_MAP_MAP_TO_LOCAL); +template +void map_to_global( + std::shared_ptr exec, + device_partition partition, + device_segmented_array remote_global_idxs, + experimental::distributed::comm_index_type rank, + const array& local_ids, + experimental::distributed::index_space is, + array& global_ids) +{ + auto range_bounds = partition.offsets_begin; + auto starting_indices = partition.starting_indices_begin; + const auto& ranges_by_part = partition.ranges_by_part; + auto local_ids_it = local_ids.get_const_data(); + auto input_size = local_ids.get_size(); + + auto policy = thrust_policy(exec); + + global_ids.resize_and_reset(local_ids.get_size()); + auto global_ids_it = global_ids.get_data(); + + auto map_local = [rank, ranges_by_part, range_bounds, starting_indices, + partition] __device__(auto lid) { + auto local_size = + static_cast(partition.part_sizes_begin[rank]); + + if (lid < 0 || lid >= local_size) { + return invalid_index(); + } + + auto local_ranges = ranges_by_part.get_segment(rank); + auto local_ranges_size = + static_cast(local_ranges.end - local_ranges.begin); + + auto it = binary_search(int64(0), local_ranges_size, [=](const auto i) { + return starting_indices[local_ranges.begin[i]] >= lid; + }); + auto local_range_id = + it != local_ranges_size ? it : max(int64(0), it - 1); + auto range_id = local_ranges.begin[local_range_id]; + + return static_cast(lid - starting_indices[range_id]) + + range_bounds[range_id]; + }; + auto map_non_local = [remote_global_idxs] __device__(auto lid) { + auto remote_size = static_cast( + remote_global_idxs.flat_end - remote_global_idxs.flat_begin); + + if (lid < 0 || lid >= remote_size) { + return invalid_index(); + } + + return remote_global_idxs.flat_begin[lid]; + }; + auto map_combined = [map_local, map_non_local, partition, + rank] __device__(auto lid) { + auto local_size = + static_cast(partition.part_sizes_begin[rank]); + + if (lid < local_size) { + return map_local(lid); + } else { + return map_non_local(lid - local_size); + } + }; + + if (is == experimental::distributed::index_space::local) { + thrust::transform(policy, local_ids_it, local_ids_it + input_size, + global_ids_it, map_local); + } + if (is == experimental::distributed::index_space::non_local) { + thrust::transform(policy, local_ids_it, local_ids_it + input_size, + global_ids_it, map_non_local); + } + if (is == experimental::distributed::index_space::combined) { + thrust::transform(policy, local_ids_it, local_ids_it + input_size, + global_ids_it, map_combined); + } +} + +GKO_INSTANTIATE_FOR_EACH_LOCAL_GLOBAL_INDEX_TYPE( + GKO_DECLARE_INDEX_MAP_MAP_TO_GLOBAL); + + } // namespace index_map } // namespace GKO_DEVICE_NAMESPACE } // namespace kernels diff --git a/core/device_hooks/common_kernels.inc.cpp b/core/device_hooks/common_kernels.inc.cpp index 9145341d6f7..fe054856695 100644 --- a/core/device_hooks/common_kernels.inc.cpp +++ b/core/device_hooks/common_kernels.inc.cpp @@ -258,6 +258,7 @@ namespace index_map { GKO_STUB_LOCAL_GLOBAL_TYPE(GKO_DECLARE_INDEX_MAP_BUILD_MAPPING); GKO_STUB_LOCAL_GLOBAL_TYPE(GKO_DECLARE_INDEX_MAP_MAP_TO_LOCAL); +GKO_STUB_LOCAL_GLOBAL_TYPE(GKO_DECLARE_INDEX_MAP_MAP_TO_GLOBAL); } // namespace index_map diff --git a/core/distributed/index_map.cpp b/core/distributed/index_map.cpp index 9f0ed8137ba..89d0f563f18 100644 --- a/core/distributed/index_map.cpp +++ b/core/distributed/index_map.cpp @@ -13,6 +13,7 @@ namespace index_map_kernels { GKO_REGISTER_OPERATION(build_mapping, index_map::build_mapping); GKO_REGISTER_OPERATION(map_to_local, index_map::map_to_local); +GKO_REGISTER_OPERATION(map_to_global, index_map::map_to_global); } // namespace index_map_kernels @@ -89,6 +90,21 @@ array index_map::map_to_local( } +template +array +index_map::map_to_global( + const array& local_ids, index_space index_space_v) const +{ + array global_ids(exec_); + + exec_->run(index_map_kernels::make_map_to_global( + to_device(partition_.get()), to_device(remote_global_idxs_), rank_, + local_ids, index_space_v, global_ids)); + + return global_ids; +} + + template index_map::index_map( std::shared_ptr exec, diff --git a/core/distributed/index_map_kernels.hpp b/core/distributed/index_map_kernels.hpp index 4694ba6cc10..2a69b8f1308 100644 --- a/core/distributed/index_map_kernels.hpp +++ b/core/distributed/index_map_kernels.hpp @@ -13,6 +13,7 @@ #include "core/base/kernel_declaration.hpp" #include "core/base/segmented_array.hpp" +#include "core/distributed/device_partition.hpp" namespace gko { @@ -55,10 +56,13 @@ namespace kernels { * * - partition: the global partition * - remote_target_ids: the owning part ids of each segment of - * remote_global_idxs + * remote_global_idxs * - remote_global_idxs: the remote global indices, segmented by the owning part * ids * - rank: the part id of this process + * + * Any global index that is not in the specified local index space is mapped + * to invalid_index. */ #define GKO_DECLARE_INDEX_MAP_MAP_TO_LOCAL(_ltype, _gtype) \ void map_to_local( \ @@ -72,11 +76,36 @@ namespace kernels { experimental::distributed::index_space is, array<_ltype>& local_ids) +/** + * This kernels maps local indices to global indices. + * + * The relevant input parameter from the index map are: + * + * - partition: the global partition + * - remote_global_idxs: the remote global indices, segmented by the owning part + * ids + * - rank: the part id of this process + * + * Any local index that is not part of the specified index space is mapped to + * invalid_index. + */ +#define GKO_DECLARE_INDEX_MAP_MAP_TO_GLOBAL(_ltype, _gtype) \ + void map_to_global( \ + std::shared_ptr exec, \ + device_partition partition, \ + device_segmented_array remote_global_idxs, \ + experimental::distributed::comm_index_type rank, \ + const array<_ltype>& local_ids, \ + experimental::distributed::index_space is, array<_gtype>& global_ids) + + #define GKO_DECLARE_ALL_AS_TEMPLATES \ template \ GKO_DECLARE_INDEX_MAP_BUILD_MAPPING(LocalIndexType, GlobalIndexType); \ template \ - GKO_DECLARE_INDEX_MAP_MAP_TO_LOCAL(LocalIndexType, GlobalIndexType) + GKO_DECLARE_INDEX_MAP_MAP_TO_LOCAL(LocalIndexType, GlobalIndexType); \ + template \ + GKO_DECLARE_INDEX_MAP_MAP_TO_GLOBAL(LocalIndexType, GlobalIndexType) GKO_DECLARE_FOR_ALL_EXECUTOR_NAMESPACES(index_map, diff --git a/dpcpp/distributed/index_map_kernels.dp.cpp b/dpcpp/distributed/index_map_kernels.dp.cpp index cf1b28140e1..4f66126ee5c 100644 --- a/dpcpp/distributed/index_map_kernels.dp.cpp +++ b/dpcpp/distributed/index_map_kernels.dp.cpp @@ -44,6 +44,19 @@ GKO_INSTANTIATE_FOR_EACH_LOCAL_GLOBAL_INDEX_TYPE( GKO_DECLARE_INDEX_MAP_MAP_TO_LOCAL); +template +void map_to_global( + std::shared_ptr exec, + device_partition partition, + device_segmented_array remote_global_idxs, + experimental::distributed::comm_index_type rank, + const array& local_ids, + experimental::distributed::index_space is, + array& global_ids) GKO_NOT_IMPLEMENTED; + +GKO_INSTANTIATE_FOR_EACH_LOCAL_GLOBAL_INDEX_TYPE( + GKO_DECLARE_INDEX_MAP_MAP_TO_GLOBAL); + } // namespace index_map } // namespace GKO_DEVICE_NAMESPACE } // namespace kernels diff --git a/include/ginkgo/core/distributed/index_map.hpp b/include/ginkgo/core/distributed/index_map.hpp index c2c2473d769..5eca83916e2 100644 --- a/include/ginkgo/core/distributed/index_map.hpp +++ b/include/ginkgo/core/distributed/index_map.hpp @@ -81,6 +81,20 @@ struct index_map { array map_to_local(const array& global_ids, index_space index_space_v) const; + + /** + * Maps local indices to global indices + * + * @param local_ids the local indices to map + * @param index_space_v the index space in which the passed-in local + * indices are defined + * + * @return the mapped global indices. Any local index, that is not in the + * specified index space is mapped to invalid_index + */ + array map_to_global(const array& local_ids, + index_space index_space_v) const; + /** * \brief get size of index_space::local */ diff --git a/omp/distributed/index_map_kernels.cpp b/omp/distributed/index_map_kernels.cpp index 7374f7b978b..7be2ec5e7a7 100644 --- a/omp/distributed/index_map_kernels.cpp +++ b/omp/distributed/index_map_kernels.cpp @@ -239,6 +239,77 @@ GKO_INSTANTIATE_FOR_EACH_LOCAL_GLOBAL_INDEX_TYPE( GKO_DECLARE_INDEX_MAP_MAP_TO_LOCAL); +template +void map_to_global( + std::shared_ptr exec, + device_partition partition, + device_segmented_array remote_global_idxs, + experimental::distributed::comm_index_type rank, + const array& local_ids, + experimental::distributed::index_space is, + array& global_ids) +{ + const auto& ranges_by_part = partition.ranges_by_part; + auto local_ranges = ranges_by_part.get_segment(rank); + + global_ids.resize_and_reset(local_ids.get_size()); + + auto local_size = + static_cast(partition.part_sizes_begin[rank]); + auto remote_size = static_cast( + remote_global_idxs.flat_end - remote_global_idxs.flat_begin); + size_type local_range_id = 0; + if (is == experimental::distributed::index_space::local) { +#pragma omp parallel for firstprivate(local_range_id) + for (size_type i = 0; i < local_ids.get_size(); ++i) { + auto lid = local_ids.get_const_data()[i]; + + if (0 <= lid && lid < local_size) { + local_range_id = + find_local_range(lid, rank, partition, local_range_id); + global_ids.get_data()[i] = map_to_global( + lid, partition, local_ranges.begin[local_range_id]); + } else { + global_ids.get_data()[i] = invalid_index(); + } + } + } + if (is == experimental::distributed::index_space::non_local) { +#pragma omp parallel for + for (size_type i = 0; i < local_ids.get_size(); ++i) { + auto lid = local_ids.get_const_data()[i]; + + if (0 <= lid && lid < remote_size) { + global_ids.get_data()[i] = remote_global_idxs.flat_begin[lid]; + } else { + global_ids.get_data()[i] = invalid_index(); + } + } + } + if (is == experimental::distributed::index_space::combined) { +#pragma omp parallel for firstprivate(local_range_id) + for (size_type i = 0; i < local_ids.get_size(); ++i) { + auto lid = local_ids.get_const_data()[i]; + + if (0 <= lid && lid < local_size) { + local_range_id = + find_local_range(lid, rank, partition, local_range_id); + global_ids.get_data()[i] = map_to_global( + lid, partition, local_ranges.begin[local_range_id]); + } else if (local_size <= lid && lid < local_size + remote_size) { + global_ids.get_data()[i] = + remote_global_idxs.flat_begin[lid - local_size]; + } else { + global_ids.get_data()[i] = invalid_index(); + } + } + } +} + +GKO_INSTANTIATE_FOR_EACH_LOCAL_GLOBAL_INDEX_TYPE( + GKO_DECLARE_INDEX_MAP_MAP_TO_GLOBAL); + + } // namespace index_map } // namespace omp } // namespace kernels diff --git a/reference/distributed/index_map_kernels.cpp b/reference/distributed/index_map_kernels.cpp index 322a95c6cdb..a47592e512a 100644 --- a/reference/distributed/index_map_kernels.cpp +++ b/reference/distributed/index_map_kernels.cpp @@ -199,6 +199,77 @@ GKO_INSTANTIATE_FOR_EACH_LOCAL_GLOBAL_INDEX_TYPE( GKO_DECLARE_INDEX_MAP_MAP_TO_LOCAL); +template +void map_to_global( + std::shared_ptr exec, + device_partition partition, + device_segmented_array remote_global_idxs, + experimental::distributed::comm_index_type rank, + const array& local_ids, + experimental::distributed::index_space is, + array& global_ids) +{ + const auto& ranges_by_part = partition.ranges_by_part; + auto local_ranges = ranges_by_part.get_segment(rank); + + global_ids.resize_and_reset(local_ids.get_size()); + + auto local_size = + static_cast(partition.part_sizes_begin[rank]); + size_type local_range_id = 0; + auto map_local = [&](auto lid) { + if (0 <= lid && lid < local_size) { + local_range_id = + find_local_range(lid, rank, partition, local_range_id); + return map_to_global(lid, partition, + local_ranges.begin[local_range_id]); + } else { + return invalid_index(); + } + }; + + auto remote_size = static_cast( + remote_global_idxs.flat_end - remote_global_idxs.flat_begin); + auto map_non_local = [&](auto lid) { + if (0 <= lid && lid < remote_size) { + return remote_global_idxs.flat_begin[lid]; + } else { + return invalid_index(); + } + }; + + auto map_combined = [&](auto lid) { + if (lid < local_size) { + return map_local(lid); + } else { + return map_non_local(lid - local_size); + } + }; + + if (is == experimental::distributed::index_space::local) { + for (size_type i = 0; i < local_ids.get_size(); ++i) { + auto lid = local_ids.get_const_data()[i]; + global_ids.get_data()[i] = map_local(lid); + } + } + if (is == experimental::distributed::index_space::non_local) { + for (size_type i = 0; i < local_ids.get_size(); ++i) { + auto lid = local_ids.get_const_data()[i]; + global_ids.get_data()[i] = map_non_local(lid); + } + } + if (is == experimental::distributed::index_space::combined) { + for (size_type i = 0; i < local_ids.get_size(); ++i) { + auto lid = local_ids.get_const_data()[i]; + global_ids.get_data()[i] = map_combined(lid); + } + } +} + +GKO_INSTANTIATE_FOR_EACH_LOCAL_GLOBAL_INDEX_TYPE( + GKO_DECLARE_INDEX_MAP_MAP_TO_GLOBAL); + + } // namespace index_map } // namespace reference } // namespace kernels diff --git a/reference/distributed/partition_helpers.hpp b/reference/distributed/partition_helpers.hpp index 06bd1e11f32..fe3035ad049 100644 --- a/reference/distributed/partition_helpers.hpp +++ b/reference/distributed/partition_helpers.hpp @@ -11,6 +11,9 @@ #include #include +#include "core/base/segmented_array.hpp" +#include "core/distributed/device_partition.hpp" + namespace gko { @@ -48,6 +51,49 @@ LocalIndexType map_to_local( } +template +size_type find_local_range( + LocalIndexType idx, size_type part_id, + device_partition partition, + const size_type local_range_id_hint = 0) +{ + const auto& ranges_by_part = partition.ranges_by_part; + auto local_ranges = ranges_by_part.get_segment(part_id); + auto local_range_size = + static_cast(local_ranges.end - local_ranges.begin); + + auto range_starting_indices = partition.starting_indices_begin; + if (range_starting_indices[local_ranges.begin[local_range_id_hint]] <= + idx && + (local_range_id_hint == local_range_size - 1 || + range_starting_indices[local_ranges.begin[local_range_id_hint + 1]] > + idx)) { + return local_range_id_hint; + } + + auto it = std::lower_bound( + local_ranges.begin, local_ranges.end, idx, + [range_starting_indices, local_ranges](const auto rid, const auto idx) { + return range_starting_indices[rid] < idx; + }); + auto local_range_id = std::distance(local_ranges.begin, it) - 1; + return local_range_id; +} + + +template +GlobalIndexType map_to_global( + LocalIndexType idx, + device_partition partition, + size_type range_id) +{ + auto range_bounds = partition.offsets_begin; + auto starting_indices = partition.starting_indices_begin; + return static_cast(idx - starting_indices[range_id]) + + range_bounds[range_id]; +} + + } // namespace gko diff --git a/reference/test/distributed/index_map_kernels.cpp b/reference/test/distributed/index_map_kernels.cpp index 72b0a0e523b..9f0062c7ebb 100644 --- a/reference/test/distributed/index_map_kernels.cpp +++ b/reference/test/distributed/index_map_kernels.cpp @@ -4,7 +4,6 @@ #include "core/distributed/index_map_kernels.hpp" -#include #include #include @@ -36,6 +35,8 @@ class IndexMap : public ::testing::Test { std::shared_ptr ref; std::shared_ptr part = part_type::build_from_mapping(ref, {ref, {0, 0, 1, 1, 2, 2}}, 3); + std::shared_ptr part_large = part_type::build_from_mapping( + ref, {ref, {0, 0, 0, 1, 1, 1, 2, 2, 2, 2, 2, 2, 1, 1, 1, 0, 0, 0}}, 3); }; @@ -195,3 +196,113 @@ TEST_F(IndexMap, CanGetLocalWithCombinedISWithInvalid) gko::array expected(ref, {2, 3, 0, 1, 2, 4, -1, 1}); GKO_ASSERT_ARRAY_EQ(local_ids, expected); } + + +TEST_F(IndexMap, CanGetGlobalWithLocalIS) +{ + gko::array global_ids(ref); + gko::array local_ids(ref, {5, 4, 3, 2, 1, 0, 4}); + auto remote_global_idxs = gko::segmented_array{ref}; + + gko::kernels::reference::index_map::map_to_global( + ref, to_device_const(part_large.get()), + to_device_const(remote_global_idxs), 1, local_ids, + gko::experimental::distributed::index_space::local, global_ids); + + gko::array expected(ref, {14, 13, 12, 5, 4, 3, 13}); + GKO_ASSERT_ARRAY_EQ(global_ids, expected); +} + + +TEST_F(IndexMap, CanGetGlobalWithLocalISWithInvalid) +{ + gko::array global_ids(ref); + gko::array local_ids(ref, {5, 4, 10, 3, 2, 1, 0, 100, 4}); + auto remote_global_idxs = gko::segmented_array{ref}; + + gko::kernels::reference::index_map::map_to_global( + ref, to_device_const(part_large.get()), + to_device_const(remote_global_idxs), 1, local_ids, + gko::experimental::distributed::index_space::local, global_ids); + + auto invalid = gko::invalid_index(); + gko::array expected( + ref, I{14, 13, invalid, 12, 5, 4, 3, invalid, 13}); + GKO_ASSERT_ARRAY_EQ(global_ids, expected); +} + + +TEST_F(IndexMap, CanGetGlobalWithNonLocalIS) +{ + gko::array global_ids(ref); + gko::array local_ids(ref, {5, 4, 3, 2, 1, 0, 4}); + auto remote_global_idxs = + gko::segmented_array::create_from_sizes( + {ref, {0, 1, 2, 17, 16, 15}}, {ref, {2, 4}}); + + gko::kernels::reference::index_map::map_to_global( + ref, to_device_const(part_large.get()), + to_device_const(remote_global_idxs), 1, local_ids, + gko::experimental::distributed::index_space::non_local, global_ids); + + gko::array expected(ref, {15, 16, 17, 2, 1, 0, 16}); + GKO_ASSERT_ARRAY_EQ(global_ids, expected); +} + + +TEST_F(IndexMap, CanGetGlobalWithNonLocalISWithInvalid) +{ + gko::array global_ids(ref); + gko::array local_ids(ref, {5, 4, 10, 3, 2, 1, 0, 100, 4}); + auto remote_global_idxs = + gko::segmented_array::create_from_sizes( + {ref, {0, 1, 2, 17, 16, 15}}, {ref, {2, 4}}); + + gko::kernels::reference::index_map::map_to_global( + ref, to_device_const(part_large.get()), + to_device_const(remote_global_idxs), 1, local_ids, + gko::experimental::distributed::index_space::non_local, global_ids); + + auto invalid = gko::invalid_index(); + gko::array expected( + ref, I{15, 16, invalid, 17, 2, 1, 0, invalid, 16}); + GKO_ASSERT_ARRAY_EQ(global_ids, expected); +} + + +TEST_F(IndexMap, CanGetGlobalWithCombinedIS) +{ + gko::array global_ids(ref); + gko::array local_ids(ref, {2, 5, 6, 10}); + auto remote_global_idxs = + gko::segmented_array::create_from_sizes( + {ref, {0, 1, 2, 17, 16, 15}}, {ref, {2, 4}}); + + gko::kernels::reference::index_map::map_to_global( + ref, to_device_const(part_large.get()), + to_device_const(remote_global_idxs), 1, local_ids, + gko::experimental::distributed::index_space::combined, global_ids); + + gko::array expected(ref, {5, 14, 0, 16}); + GKO_ASSERT_ARRAY_EQ(global_ids, expected); +} + + +TEST_F(IndexMap, CanGetGlobalWithCombinedISWithInvalid) +{ + gko::array global_ids(ref); + gko::array local_ids(ref, {2, 5, 133, 6, 10}); + auto remote_global_idxs = + gko::segmented_array::create_from_sizes( + {ref, {0, 1, 2, 17, 16, 15}}, {ref, {2, 4}}); + + gko::kernels::reference::index_map::map_to_global( + ref, to_device_const(part_large.get()), + to_device_const(remote_global_idxs), 1, local_ids, + gko::experimental::distributed::index_space::combined, global_ids); + + auto invalid = gko::invalid_index(); + gko::array expected( + ref, I{5, 14, invalid, 0, 16}); + GKO_ASSERT_ARRAY_EQ(global_ids, expected); +} diff --git a/test/distributed/index_map_kernels.cpp b/test/distributed/index_map_kernels.cpp index 4fb6f111123..afc90d77aba 100644 --- a/test/distributed/index_map_kernels.cpp +++ b/test/distributed/index_map_kernels.cpp @@ -166,6 +166,17 @@ class IndexMap : public CommonTestFixture { return {std::move(exec), std::move(query)}; } + gko::array generate_to_global_query( + std::shared_ptr exec, gko::size_type size, + gko::size_type num_queries) + { + std::uniform_int_distribution dist(0, size - 1); + gko::array query{ref, num_queries}; + std::generate_n(query.get_data(), query.get_size(), + [&] { return dist(engine); }); + return {std::move(exec), std::move(query)}; + } + gko::array generate_complement_idxs( std::shared_ptr exec, const gko::array& idxs) @@ -388,3 +399,127 @@ TEST_F(IndexMap, GetLocalWithCombinedIndexSpaceWithInvalidIndexSameAsRef) GKO_ASSERT_ARRAY_EQ(result, dresult); } + + +TEST_F(IndexMap, GetGlobalWithLocalIndexSpaceSameAsRef) +{ + auto query = generate_to_global_query(ref, local_size, 33); + auto dquery = gko::array(exec, query); + auto result = gko::array(ref); + auto dresult = gko::array(exec); + + gko::kernels::reference::index_map::map_to_global( + ref, to_device_const(part.get()), to_device_const(remote_global_idxs), + this_rank, query, gko::experimental::distributed::index_space::local, + result); + gko::kernels::GKO_DEVICE_NAMESPACE::index_map::map_to_global( + exec, to_device_const(dpart.get()), + to_device_const(dremote_global_idxs), this_rank, dquery, + gko::experimental::distributed::index_space::local, dresult); + + GKO_ASSERT_ARRAY_EQ(result, dresult); +} + + +TEST_F(IndexMap, GetGlobalWithLocalIndexSpaceWithInvalidIndexSameAsRef) +{ + auto query = generate_to_global_query(ref, local_size * 2, 33); + auto dquery = gko::array(exec, query); + auto result = gko::array(ref); + auto dresult = gko::array(exec); + + gko::kernels::reference::index_map::map_to_global( + ref, to_device_const(part.get()), to_device_const(remote_global_idxs), + this_rank, query, gko::experimental::distributed::index_space::local, + result); + gko::kernels::GKO_DEVICE_NAMESPACE::index_map::map_to_global( + exec, to_device_const(dpart.get()), + to_device_const(dremote_global_idxs), this_rank, dquery, + gko::experimental::distributed::index_space::local, dresult); + + GKO_ASSERT_ARRAY_EQ(result, dresult); +} + + +TEST_F(IndexMap, GetGlobalWithNonLocalIndexSpaceSameAsRef) +{ + auto query = + generate_to_global_query(ref, remote_global_idxs.get_size(), 33); + auto dquery = gko::array(exec, query); + auto result = gko::array(ref); + auto dresult = gko::array(exec); + + gko::kernels::reference::index_map::map_to_global( + ref, to_device_const(part.get()), to_device_const(remote_global_idxs), + this_rank, query, + gko::experimental::distributed::index_space::non_local, result); + gko::kernels::GKO_DEVICE_NAMESPACE::index_map::map_to_global( + exec, to_device_const(dpart.get()), + to_device_const(dremote_global_idxs), this_rank, dquery, + gko::experimental::distributed::index_space::non_local, dresult); + + GKO_ASSERT_ARRAY_EQ(result, dresult); +} + + +TEST_F(IndexMap, GetGlobalWithNonLocalIndexSpaceWithInvalidIndexSameAsRef) +{ + auto query = + generate_to_global_query(ref, remote_global_idxs.get_size() * 2, 33); + auto dquery = gko::array(exec, query); + auto result = gko::array(ref); + auto dresult = gko::array(exec); + + gko::kernels::reference::index_map::map_to_global( + ref, to_device_const(part.get()), to_device_const(remote_global_idxs), + this_rank, query, + gko::experimental::distributed::index_space::non_local, result); + gko::kernels::GKO_DEVICE_NAMESPACE::index_map::map_to_global( + exec, to_device_const(dpart.get()), + to_device_const(dremote_global_idxs), this_rank, dquery, + gko::experimental::distributed::index_space::non_local, dresult); + + GKO_ASSERT_ARRAY_EQ(result, dresult); +} + + +TEST_F(IndexMap, GetGlobalWithCombinedIndexSpaceSameAsRef) +{ + auto query = generate_to_global_query( + ref, local_size + remote_global_idxs.get_size(), 33); + auto dquery = gko::array(exec, query); + auto result = gko::array(ref); + auto dresult = gko::array(exec); + + gko::kernels::reference::index_map::map_to_global( + ref, to_device_const(part.get()), to_device_const(remote_global_idxs), + this_rank, query, gko::experimental::distributed::index_space::combined, + result); + gko::kernels::GKO_DEVICE_NAMESPACE::index_map::map_to_global( + exec, to_device_const(dpart.get()), + to_device_const(dremote_global_idxs), this_rank, dquery, + gko::experimental::distributed::index_space::combined, dresult); + + GKO_ASSERT_ARRAY_EQ(result, dresult); +} + + +TEST_F(IndexMap, GetGlobalWithCombinedIndexSpaceWithInvalidIndexSameAsRef) +{ + auto query = generate_to_global_query( + ref, (local_size + remote_global_idxs.get_size()) * 2, 33); + auto dquery = gko::array(exec, query); + auto result = gko::array(ref); + auto dresult = gko::array(exec); + + gko::kernels::reference::index_map::map_to_global( + ref, to_device_const(part.get()), to_device_const(remote_global_idxs), + this_rank, query, gko::experimental::distributed::index_space::combined, + result); + gko::kernels::GKO_DEVICE_NAMESPACE::index_map::map_to_global( + exec, to_device_const(dpart.get()), + to_device_const(dremote_global_idxs), this_rank, dquery, + gko::experimental::distributed::index_space::combined, dresult); + + GKO_ASSERT_ARRAY_EQ(result, dresult); +}