Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Create matrices from const data #890

Merged
merged 3 commits into from
Oct 9, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
61 changes: 61 additions & 0 deletions core/test/base/array.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,7 @@ OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.


#include <algorithm>
#include <type_traits>


#include <gtest/gtest.h>
Expand Down Expand Up @@ -591,4 +592,64 @@ TYPED_TEST(Array, MoveArrayToView)
}


TYPED_TEST(Array, AsView)
{
auto ptr = this->x.get_data();
auto size = this->x.get_num_elems();
auto exec = this->x.get_executor();
auto view = this->x.as_view();

ASSERT_EQ(ptr, this->x.get_data());
ASSERT_EQ(ptr, view.get_data());
ASSERT_EQ(size, this->x.get_num_elems());
ASSERT_EQ(size, view.get_num_elems());
ASSERT_EQ(exec, this->x.get_executor());
ASSERT_EQ(exec, view.get_executor());
ASSERT_TRUE(this->x.is_owning());
ASSERT_FALSE(view.is_owning());
}


TYPED_TEST(Array, AsConstView)
{
auto ptr = this->x.get_data();
auto size = this->x.get_num_elems();
auto exec = this->x.get_executor();
auto view = this->x.as_const_view();

ASSERT_EQ(ptr, this->x.get_data());
ASSERT_EQ(ptr, view.get_const_data());
ASSERT_EQ(size, this->x.get_num_elems());
ASSERT_EQ(size, view.get_num_elems());
ASSERT_EQ(exec, this->x.get_executor());
ASSERT_EQ(exec, view.get_executor());
ASSERT_TRUE(this->x.is_owning());
ASSERT_FALSE(view.is_owning());
}


TYPED_TEST(Array, ArrayConstCastWorksOnView)
{
auto ptr = this->x.get_data();
auto size = this->x.get_num_elems();
auto exec = this->x.get_executor();
auto const_view = this->x.as_const_view();
auto view = gko::detail::array_const_cast(std::move(const_view));
static_assert(std::is_same<decltype(view), decltype(this->x)>::value,
"wrong return type");

ASSERT_EQ(nullptr, const_view.get_const_data());
ASSERT_EQ(0, const_view.get_num_elems());
ASSERT_EQ(exec, const_view.get_executor());
ASSERT_EQ(ptr, this->x.get_data());
ASSERT_EQ(ptr, view.get_const_data());
ASSERT_EQ(size, this->x.get_num_elems());
ASSERT_EQ(size, view.get_num_elems());
ASSERT_EQ(exec, this->x.get_executor());
ASSERT_EQ(exec, view.get_executor());
ASSERT_TRUE(this->x.is_owning());
ASSERT_FALSE(view.is_owning());
}


} // namespace
20 changes: 20 additions & 0 deletions core/test/matrix/coo.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -152,6 +152,26 @@ TYPED_TEST(Coo, CanBeCreatedFromExistingData)
}


TYPED_TEST(Coo, CanBeCreatedFromExistingConstData)
{
using value_type = typename TestFixture::value_type;
using index_type = typename TestFixture::index_type;
const value_type values[] = {1.0, 2.0, 3.0, 4.0};
const index_type col_idxs[] = {0, 1, 1, 0};
const index_type row_idxs[] = {0, 0, 1, 2};

auto mtx = gko::matrix::Coo<value_type, index_type>::create_const(
MarcelKoch marked this conversation as resolved.
Show resolved Hide resolved
this->exec, gko::dim<2>{3, 2},
gko::Array<value_type>::const_view(this->exec, 4, values),
gko::Array<index_type>::const_view(this->exec, 4, col_idxs),
gko::Array<index_type>::const_view(this->exec, 4, row_idxs));

ASSERT_EQ(mtx->get_const_values(), values);
ASSERT_EQ(mtx->get_const_col_idxs(), col_idxs);
ASSERT_EQ(mtx->get_const_row_idxs(), row_idxs);
}


TYPED_TEST(Coo, CanBeCopied)
{
using Mtx = typename TestFixture::Mtx;
Expand Down
24 changes: 24 additions & 0 deletions core/test/matrix/csr.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -160,6 +160,30 @@ TYPED_TEST(Csr, CanBeCreatedFromExistingData)
}


TYPED_TEST(Csr, CanBeCreatedFromExistingConstData)
{
using Mtx = typename TestFixture::Mtx;
using value_type = typename TestFixture::value_type;
using index_type = typename TestFixture::index_type;
const value_type values[] = {1.0, 2.0, 3.0, 4.0};
const index_type col_idxs[] = {0, 1, 1, 0};
const index_type row_ptrs[] = {0, 2, 3, 4};

auto mtx = gko::matrix::Csr<value_type, index_type>::create_const(
this->exec, gko::dim<2>{3, 2},
gko::Array<value_type>::const_view(this->exec, 4, values),
gko::Array<index_type>::const_view(this->exec, 4, col_idxs),
gko::Array<index_type>::const_view(this->exec, 4, row_ptrs),
std::make_shared<typename Mtx::load_balance>(2));

ASSERT_EQ(mtx->get_num_srow_elements(), 1);
ASSERT_EQ(mtx->get_const_values(), values);
ASSERT_EQ(mtx->get_const_col_idxs(), col_idxs);
ASSERT_EQ(mtx->get_const_row_ptrs(), row_ptrs);
ASSERT_EQ(mtx->get_const_srow()[0], 0);
}


TYPED_TEST(Csr, CanBeCopied)
{
using Mtx = typename TestFixture::Mtx;
Expand Down
19 changes: 19 additions & 0 deletions core/test/matrix/dense.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -138,6 +138,25 @@ TYPED_TEST(Dense, CanBeConstructedFromExistingData)
}


TYPED_TEST(Dense, CanBeConstructedFromExistingConstData)
{
using value_type = typename TestFixture::value_type;
// clang-format off
const value_type data[] = {
1.0, 2.0, -1.0,
3.0, 4.0, -1.0,
5.0, 6.0, -1.0};
// clang-format on

auto m = gko::matrix::Dense<TypeParam>::create_const(
this->exec, gko::dim<2>{3, 2},
gko::Array<value_type>::const_view(this->exec, 9, data), 3);

ASSERT_EQ(m->get_const_values(), data);
ASSERT_EQ(m->at(2, 1), value_type{6.0});
}


TYPED_TEST(Dense, CreateWithSameConfigKeepsStride)
{
auto m =
Expand Down
13 changes: 13 additions & 0 deletions core/test/matrix/diagonal.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -114,6 +114,19 @@ TYPED_TEST(Diagonal, CanBeCreatedFromExistingData)
}


TYPED_TEST(Diagonal, CanBeCreatedFromExistingConstData)
{
using value_type = typename TestFixture::value_type;
const value_type values[] = {1.0, 2.0, 3.0};

auto diag = gko::matrix::Diagonal<value_type>::create_const(
this->exec, 3,
gko::Array<value_type>::const_view(this->exec, 3, values));

ASSERT_EQ(diag->get_const_values(), values);
}


TYPED_TEST(Diagonal, CanBeCopied)
{
using Diag = typename TestFixture::Diag;
Expand Down
17 changes: 17 additions & 0 deletions core/test/matrix/ell.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -154,6 +154,23 @@ TYPED_TEST(Ell, CanBeCreatedFromExistingData)
}


TYPED_TEST(Ell, CanBeCreatedFromExistingConstData)
{
using value_type = typename TestFixture::value_type;
using index_type = typename TestFixture::index_type;
const value_type values[] = {1.0, 3.0, 4.0, -1.0, 2.0, 0.0, 0.0, -1.0};
const index_type col_idxs[] = {0, 1, 0, -1, 1, 0, 0, -1};

auto mtx = gko::matrix::Ell<value_type, index_type>::create_const(
this->exec, gko::dim<2>{3, 2},
gko::Array<value_type>::const_view(this->exec, 8, values),
gko::Array<index_type>::const_view(this->exec, 8, col_idxs), 2, 4);

ASSERT_EQ(mtx->get_const_values(), values);
ASSERT_EQ(mtx->get_const_col_idxs(), col_idxs);
}


TYPED_TEST(Ell, CanBeCopied)
{
using Mtx = typename TestFixture::Mtx;
Expand Down
26 changes: 26 additions & 0 deletions core/test/matrix/fbcsr.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -404,6 +404,32 @@ TYPED_TEST(Fbcsr, CanBeCreatedFromExistingData)
}


TYPED_TEST(Fbcsr, CanBeCreatedFromExistingConstData)
{
using value_type = typename TestFixture::value_type;
using index_type = typename TestFixture::index_type;
using size_type = gko::size_type;
const int bs = this->fbsample.bs;
const size_type nbrows = this->fbsample.nbrows;
const size_type nbcols = this->fbsample.nbcols;
const size_type bnnz = this->fbsample.nbnz;
auto refmat = this->fbsample.generate_fbcsr();
auto values = refmat->get_const_values();
auto col_idxs = refmat->get_const_col_idxs();
auto row_ptrs = refmat->get_const_row_ptrs();

auto mtx = gko::matrix::Fbcsr<value_type, index_type>::create_const(
this->exec, gko::dim<2>{nbrows * bs, nbcols * bs}, bs,
gko::Array<value_type>::const_view(this->exec, bnnz * bs * bs, values),
gko::Array<index_type>::const_view(this->exec, bnnz, col_idxs),
gko::Array<index_type>::const_view(this->exec, nbrows + 1, row_ptrs));

ASSERT_EQ(mtx->get_const_values(), values);
ASSERT_EQ(mtx->get_const_col_idxs(), col_idxs);
ASSERT_EQ(mtx->get_const_row_ptrs(), row_ptrs);
}


TYPED_TEST(Fbcsr, CanBeCopied)
{
using Mtx = typename TestFixture::Mtx;
Expand Down
13 changes: 13 additions & 0 deletions core/test/matrix/permutation.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -142,6 +142,19 @@ TYPED_TEST(Permutation, PermutationCanBeConstructedFromExistingData)
}


TYPED_TEST(Permutation, PermutationCanBeConstructedFromExistingConstData)
{
using i_type = typename TestFixture::i_type;
using i_type = typename TestFixture::i_type;
const i_type data[] = {1, 0, 2};

auto m = gko::matrix::Permutation<i_type>::create_const(
this->exec, 3, gko::Array<i_type>::const_view(this->exec, 3, data));

ASSERT_EQ(m->get_const_permutation(), data);
}


TYPED_TEST(Permutation, CanBeConstructedWithSizeAndMask)
{
using i_type = typename TestFixture::i_type;
Expand Down
18 changes: 18 additions & 0 deletions core/test/matrix/sparsity_csr.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -165,6 +165,24 @@ TYPED_TEST(SparsityCsr, CanBeCreatedFromExistingData)
}


TYPED_TEST(SparsityCsr, CanBeCreatedFromExistingConstData)
{
using value_type = typename TestFixture::value_type;
using index_type = typename TestFixture::index_type;
const index_type col_idxs[] = {0, 1, 1, 0};
const index_type row_ptrs[] = {0, 2, 3, 4};

auto mtx = gko::matrix::SparsityCsr<value_type, index_type>::create_const(
this->exec, gko::dim<2>{3, 2},
gko::Array<index_type>::const_view(this->exec, 4, col_idxs),
gko::Array<index_type>::const_view(this->exec, 4, row_ptrs), 2.0);

ASSERT_EQ(mtx->get_const_col_idxs(), col_idxs);
ASSERT_EQ(mtx->get_const_row_ptrs(), row_ptrs);
ASSERT_EQ(mtx->get_const_value()[0], value_type{2.0});
}


TYPED_TEST(SparsityCsr, CanBeCopied)
{
using Mtx = typename TestFixture::Mtx;
Expand Down
Loading